mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 21:33:53 +01:00
Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.13.0
This commit is contained in:
commit
e66d0bd03a
91 changed files with 1796 additions and 1101 deletions
|
@ -51,3 +51,6 @@ Steven Hammerton <steven.hammerton at openmarket.com>
|
||||||
|
|
||||||
Mads Robin Christensen <mads at v42 dot dk>
|
Mads Robin Christensen <mads at v42 dot dk>
|
||||||
* CentOS 7 installation instructions.
|
* CentOS 7 installation instructions.
|
||||||
|
|
||||||
|
Florent Violleau <floviolleau at gmail dot com>
|
||||||
|
* Add Raspberry Pi installation instructions and general troubleshooting items
|
21
README.rst
21
README.rst
|
@ -125,6 +125,15 @@ Installing prerequisites on Mac OS X::
|
||||||
sudo easy_install pip
|
sudo easy_install pip
|
||||||
sudo pip install virtualenv
|
sudo pip install virtualenv
|
||||||
|
|
||||||
|
Installing prerequisites on Raspbian::
|
||||||
|
|
||||||
|
sudo apt-get install build-essential python2.7-dev libffi-dev \
|
||||||
|
python-pip python-setuptools sqlite3 \
|
||||||
|
libssl-dev python-virtualenv libjpeg-dev
|
||||||
|
sudo pip install --upgrade pip
|
||||||
|
sudo pip install --upgrade ndg-httpsclient
|
||||||
|
sudo pip install --upgrade virtualenv
|
||||||
|
|
||||||
To install the synapse homeserver run::
|
To install the synapse homeserver run::
|
||||||
|
|
||||||
virtualenv -p python2.7 ~/.synapse
|
virtualenv -p python2.7 ~/.synapse
|
||||||
|
@ -310,6 +319,18 @@ may need to manually upgrade it::
|
||||||
|
|
||||||
sudo pip install --upgrade pip
|
sudo pip install --upgrade pip
|
||||||
|
|
||||||
|
Installing may fail with ``Could not find any downloads that satisfy the requirement pymacaroons-pynacl (from matrix-synapse==0.12.0)``.
|
||||||
|
You can fix this by manually upgrading pip and virtualenv::
|
||||||
|
|
||||||
|
sudo pip install --upgrade virtualenv
|
||||||
|
|
||||||
|
You can next rerun ``virtualenv -p python2.7 synapse`` to update the virtual env.
|
||||||
|
|
||||||
|
Installing may fail during installing virtualenv with ``InsecurePlatformWarning: A true SSLContext object is not available. This prevents urllib3 from configuring SSL appropriately and may cause certain SSL connections to fail. For more information, see https://urllib3.readthedocs.org/en/latest/security.html#insecureplatformwarning.``
|
||||||
|
You can fix this by manually installing ndg-httpsclient::
|
||||||
|
|
||||||
|
pip install --upgrade ndg-httpsclient
|
||||||
|
|
||||||
Installing may fail with ``mock requires setuptools>=17.1. Aborting installation``.
|
Installing may fail with ``mock requires setuptools>=17.1. Aborting installation``.
|
||||||
You can fix this by upgrading setuptools::
|
You can fix this by upgrading setuptools::
|
||||||
|
|
||||||
|
|
24
scripts-dev/dump_macaroon.py
Executable file
24
scripts-dev/dump_macaroon.py
Executable file
|
@ -0,0 +1,24 @@
|
||||||
|
#!/usr/bin/env python2
|
||||||
|
|
||||||
|
import pymacaroons
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if len(sys.argv) == 1:
|
||||||
|
sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
macaroon_string = sys.argv[1]
|
||||||
|
key = sys.argv[2] if len(sys.argv) > 2 else None
|
||||||
|
|
||||||
|
macaroon = pymacaroons.Macaroon.deserialize(macaroon_string)
|
||||||
|
print macaroon.inspect()
|
||||||
|
|
||||||
|
print ""
|
||||||
|
|
||||||
|
verifier = pymacaroons.Verifier()
|
||||||
|
verifier.satisfy_general(lambda c: True)
|
||||||
|
try:
|
||||||
|
verifier.verify(macaroon, key)
|
||||||
|
print "Signature is correct"
|
||||||
|
except Exception as e:
|
||||||
|
print e.message
|
62
scripts-dev/list_url_patterns.py
Executable file
62
scripts-dev/list_url_patterns.py
Executable file
|
@ -0,0 +1,62 @@
|
||||||
|
#! /usr/bin/python
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
PATTERNS_V1 = []
|
||||||
|
PATTERNS_V2 = []
|
||||||
|
|
||||||
|
RESULT = {
|
||||||
|
"v1": PATTERNS_V1,
|
||||||
|
"v2": PATTERNS_V2,
|
||||||
|
}
|
||||||
|
|
||||||
|
class CallVisitor(ast.NodeVisitor):
|
||||||
|
def visit_Call(self, node):
|
||||||
|
if isinstance(node.func, ast.Name):
|
||||||
|
name = node.func.id
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if name == "client_path_patterns":
|
||||||
|
PATTERNS_V1.append(node.args[0].s)
|
||||||
|
elif name == "client_v2_patterns":
|
||||||
|
PATTERNS_V2.append(node.args[0].s)
|
||||||
|
|
||||||
|
|
||||||
|
def find_patterns_in_code(input_code):
|
||||||
|
input_ast = ast.parse(input_code)
|
||||||
|
visitor = CallVisitor()
|
||||||
|
visitor.visit(input_ast)
|
||||||
|
|
||||||
|
|
||||||
|
def find_patterns_in_file(filepath):
|
||||||
|
with open(filepath) as f:
|
||||||
|
find_patterns_in_code(f.read())
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Find url patterns.')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"directories", nargs='+', metavar="DIR",
|
||||||
|
help="Directories to search for definitions"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
for directory in args.directories:
|
||||||
|
for root, dirs, files in os.walk(directory):
|
||||||
|
for filename in files:
|
||||||
|
if filename.endswith(".py"):
|
||||||
|
filepath = os.path.join(root, filename)
|
||||||
|
find_patterns_in_file(filepath)
|
||||||
|
|
||||||
|
PATTERNS_V1.sort()
|
||||||
|
PATTERNS_V2.sort()
|
||||||
|
|
||||||
|
yaml.dump(RESULT, sys.stdout, default_flow_style=False)
|
|
@ -16,3 +16,4 @@ ignore =
|
||||||
|
|
||||||
[flake8]
|
[flake8]
|
||||||
max-line-length = 90
|
max-line-length = 90
|
||||||
|
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
||||||
|
|
|
@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||||
from synapse.types import Requester, RoomID, UserID, EventID
|
from synapse.types import Requester, RoomID, UserID, EventID
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -529,7 +530,8 @@ class Auth(object):
|
||||||
default=[""]
|
default=[""]
|
||||||
)[0]
|
)[0]
|
||||||
if user and access_token and ip_addr:
|
if user and access_token and ip_addr:
|
||||||
self.store.insert_client_ip(
|
preserve_context_over_fn(
|
||||||
|
self.store.insert_client_ip,
|
||||||
user=user,
|
user=user,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
ip=ip_addr,
|
ip=ip_addr,
|
||||||
|
@ -696,6 +698,7 @@ class Auth(object):
|
||||||
def _look_up_user_by_access_token(self, token):
|
def _look_up_user_by_access_token(self, token):
|
||||||
ret = yield self.store.get_user_by_access_token(token)
|
ret = yield self.store.get_user_by_access_token(token)
|
||||||
if not ret:
|
if not ret:
|
||||||
|
logger.warn("Unrecognised access token - not in store: %s" % (token,))
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
|
||||||
errcode=Codes.UNKNOWN_TOKEN
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
|
@ -713,6 +716,7 @@ class Auth(object):
|
||||||
token = request.args["access_token"][0]
|
token = request.args["access_token"][0]
|
||||||
service = yield self.store.get_app_service_by_token(token)
|
service = yield self.store.get_app_service_by_token(token)
|
||||||
if not service:
|
if not service:
|
||||||
|
logger.warn("Unrecognised appservice access token: %s" % (token,))
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS,
|
self.TOKEN_NOT_FOUND_HTTP_STATUS,
|
||||||
"Unrecognised access token.",
|
"Unrecognised access token.",
|
||||||
|
|
|
@ -23,5 +23,6 @@ WEB_CLIENT_PREFIX = "/_matrix/client"
|
||||||
CONTENT_REPO_PREFIX = "/_matrix/content"
|
CONTENT_REPO_PREFIX = "/_matrix/content"
|
||||||
SERVER_KEY_PREFIX = "/_matrix/key/v1"
|
SERVER_KEY_PREFIX = "/_matrix/key/v1"
|
||||||
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
|
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
|
||||||
MEDIA_PREFIX = "/_matrix/media/v1"
|
MEDIA_PREFIX = "/_matrix/media/r0"
|
||||||
|
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
|
||||||
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"
|
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"
|
||||||
|
|
|
@ -12,3 +12,22 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.dont_write_bytecode = True
|
||||||
|
|
||||||
|
from synapse.python_dependencies import (
|
||||||
|
check_requirements, MissingRequirementError
|
||||||
|
) # NOQA
|
||||||
|
|
||||||
|
try:
|
||||||
|
check_requirements()
|
||||||
|
except MissingRequirementError as e:
|
||||||
|
message = "\n".join([
|
||||||
|
"Missing Requirement: %s" % (e.message,),
|
||||||
|
"To install run:",
|
||||||
|
" pip install --upgrade --force \"%s\"" % (e.dependency,),
|
||||||
|
"",
|
||||||
|
])
|
||||||
|
sys.stderr.writelines(message)
|
||||||
|
sys.exit(1)
|
||||||
|
|
|
@ -14,27 +14,23 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import sys
|
import synapse
|
||||||
from synapse.rest import ClientRestResource
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import resource
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from synapse.config._base import ConfigError
|
||||||
|
|
||||||
sys.dont_write_bytecode = True
|
|
||||||
from synapse.python_dependencies import (
|
from synapse.python_dependencies import (
|
||||||
check_requirements, DEPENDENCY_LINKS, MissingRequirementError
|
check_requirements, DEPENDENCY_LINKS
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
from synapse.rest import ClientRestResource
|
||||||
try:
|
|
||||||
check_requirements()
|
|
||||||
except MissingRequirementError as e:
|
|
||||||
message = "\n".join([
|
|
||||||
"Missing Requirement: %s" % (e.message,),
|
|
||||||
"To install run:",
|
|
||||||
" pip install --upgrade --force \"%s\"" % (e.dependency,),
|
|
||||||
"",
|
|
||||||
])
|
|
||||||
sys.stderr.writelines(message)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
|
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
|
||||||
from synapse.storage import are_all_users_on_domain
|
from synapse.storage import are_all_users_on_domain
|
||||||
from synapse.storage.prepare_database import UpgradeDatabaseException
|
from synapse.storage.prepare_database import UpgradeDatabaseException
|
||||||
|
@ -60,7 +56,7 @@ from synapse.rest.key.v1.server_key_resource import LocalKey
|
||||||
from synapse.rest.key.v2 import KeyApiV2Resource
|
from synapse.rest.key.v2 import KeyApiV2Resource
|
||||||
from synapse.api.urls import (
|
from synapse.api.urls import (
|
||||||
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
|
FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
|
||||||
SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
|
SERVER_KEY_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
|
||||||
SERVER_KEY_V2_PREFIX,
|
SERVER_KEY_V2_PREFIX,
|
||||||
)
|
)
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
@ -73,17 +69,6 @@ from synapse import events
|
||||||
|
|
||||||
from daemonize import Daemonize
|
from daemonize import Daemonize
|
||||||
|
|
||||||
import synapse
|
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import resource
|
|
||||||
import subprocess
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("synapse.app.homeserver")
|
logger = logging.getLogger("synapse.app.homeserver")
|
||||||
|
|
||||||
|
|
||||||
|
@ -163,8 +148,10 @@ class SynapseHomeServer(HomeServer):
|
||||||
})
|
})
|
||||||
|
|
||||||
if name in ["media", "federation", "client"]:
|
if name in ["media", "federation", "client"]:
|
||||||
|
media_repo = MediaRepositoryResource(self)
|
||||||
resources.update({
|
resources.update({
|
||||||
MEDIA_PREFIX: MediaRepositoryResource(self),
|
MEDIA_PREFIX: media_repo,
|
||||||
|
LEGACY_MEDIA_PREFIX: media_repo,
|
||||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||||
self, self.config.uploads_path, self.auth, self.content_addr
|
self, self.config.uploads_path, self.auth, self.content_addr
|
||||||
),
|
),
|
||||||
|
@ -366,11 +353,20 @@ def setup(config_options):
|
||||||
Returns:
|
Returns:
|
||||||
HomeServer
|
HomeServer
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
config = HomeServerConfig.load_config(
|
config = HomeServerConfig.load_config(
|
||||||
"Synapse Homeserver",
|
"Synapse Homeserver",
|
||||||
config_options,
|
config_options,
|
||||||
generate_section="Homeserver"
|
generate_section="Homeserver"
|
||||||
)
|
)
|
||||||
|
except ConfigError as e:
|
||||||
|
sys.stderr.write("\n" + e.message + "\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
# If a config isn't returned, and an exception isn't raised, we're just
|
||||||
|
# generating config files and shouldn't try to continue.
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
config.setup_logging()
|
config.setup_logging()
|
||||||
|
|
||||||
|
@ -690,8 +686,8 @@ def run(hs):
|
||||||
stats["uptime_seconds"] = uptime
|
stats["uptime_seconds"] = uptime
|
||||||
stats["total_users"] = yield hs.get_datastore().count_all_users()
|
stats["total_users"] = yield hs.get_datastore().count_all_users()
|
||||||
|
|
||||||
all_rooms = yield hs.get_datastore().get_rooms(False)
|
room_count = yield hs.get_datastore().get_room_count()
|
||||||
stats["total_room_count"] = len(all_rooms)
|
stats["total_room_count"] = room_count
|
||||||
|
|
||||||
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
|
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
|
||||||
daily_messages = yield hs.get_datastore().count_daily_messages()
|
daily_messages = yield hs.get_datastore().count_daily_messages()
|
||||||
|
@ -713,6 +709,8 @@ def run(hs):
|
||||||
phone_home_task.start(60 * 60 * 24, now=False)
|
phone_home_task.start(60 * 60 * 24, now=False)
|
||||||
|
|
||||||
def in_thread():
|
def in_thread():
|
||||||
|
# Uncomment to enable tracing of log context changes.
|
||||||
|
# sys.settrace(logcontext_tracer)
|
||||||
with LoggingContext("run"):
|
with LoggingContext("run"):
|
||||||
change_resource_limit(hs.config.soft_file_limit)
|
change_resource_limit(hs.config.soft_file_limit)
|
||||||
reactor.run()
|
reactor.run()
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from synapse.config._base import ConfigError
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
|
@ -21,7 +22,11 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
if action == "read":
|
if action == "read":
|
||||||
key = sys.argv[2]
|
key = sys.argv[2]
|
||||||
|
try:
|
||||||
config = HomeServerConfig.load_config("", sys.argv[3:])
|
config = HomeServerConfig.load_config("", sys.argv[3:])
|
||||||
|
except ConfigError as e:
|
||||||
|
sys.stderr.write("\n" + e.message + "\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
print getattr(config, key)
|
print getattr(config, key)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
|
@ -17,7 +17,6 @@ import argparse
|
||||||
import errno
|
import errno
|
||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
import sys
|
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
|
||||||
|
|
||||||
|
@ -136,13 +135,20 @@ class Config(object):
|
||||||
results.append(getattr(cls, name)(self, *args, **kargs))
|
results.append(getattr(cls, name)(self, *args, **kargs))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def generate_config(self, config_dir_path, server_name, report_stats=None):
|
def generate_config(
|
||||||
|
self,
|
||||||
|
config_dir_path,
|
||||||
|
server_name,
|
||||||
|
is_generating_file,
|
||||||
|
report_stats=None,
|
||||||
|
):
|
||||||
default_config = "# vim:ft=yaml\n"
|
default_config = "# vim:ft=yaml\n"
|
||||||
|
|
||||||
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
|
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
|
||||||
"default_config",
|
"default_config",
|
||||||
config_dir_path=config_dir_path,
|
config_dir_path=config_dir_path,
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
|
is_generating_file=is_generating_file,
|
||||||
report_stats=report_stats,
|
report_stats=report_stats,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@ -244,8 +250,10 @@ class Config(object):
|
||||||
|
|
||||||
server_name = config_args.server_name
|
server_name = config_args.server_name
|
||||||
if not server_name:
|
if not server_name:
|
||||||
print "Must specify a server_name to a generate config for."
|
raise ConfigError(
|
||||||
sys.exit(1)
|
"Must specify a server_name to a generate config for."
|
||||||
|
" Pass -H server.name."
|
||||||
|
)
|
||||||
if not os.path.exists(config_dir_path):
|
if not os.path.exists(config_dir_path):
|
||||||
os.makedirs(config_dir_path)
|
os.makedirs(config_dir_path)
|
||||||
with open(config_path, "wb") as config_file:
|
with open(config_path, "wb") as config_file:
|
||||||
|
@ -253,6 +261,7 @@ class Config(object):
|
||||||
config_dir_path=config_dir_path,
|
config_dir_path=config_dir_path,
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
report_stats=(config_args.report_stats == "yes"),
|
report_stats=(config_args.report_stats == "yes"),
|
||||||
|
is_generating_file=True
|
||||||
)
|
)
|
||||||
obj.invoke_all("generate_files", config)
|
obj.invoke_all("generate_files", config)
|
||||||
config_file.write(config_bytes)
|
config_file.write(config_bytes)
|
||||||
|
@ -266,7 +275,7 @@ class Config(object):
|
||||||
"If this server name is incorrect, you will need to"
|
"If this server name is incorrect, you will need to"
|
||||||
" regenerate the SSL certificates"
|
" regenerate the SSL certificates"
|
||||||
)
|
)
|
||||||
sys.exit(0)
|
return
|
||||||
else:
|
else:
|
||||||
print (
|
print (
|
||||||
"Config file %r already exists. Generating any missing key"
|
"Config file %r already exists. Generating any missing key"
|
||||||
|
@ -302,25 +311,25 @@ class Config(object):
|
||||||
specified_config.update(yaml_config)
|
specified_config.update(yaml_config)
|
||||||
|
|
||||||
if "server_name" not in specified_config:
|
if "server_name" not in specified_config:
|
||||||
sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n")
|
raise ConfigError(MISSING_SERVER_NAME)
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
server_name = specified_config["server_name"]
|
server_name = specified_config["server_name"]
|
||||||
_, config = obj.generate_config(
|
_, config = obj.generate_config(
|
||||||
config_dir_path=config_dir_path,
|
config_dir_path=config_dir_path,
|
||||||
server_name=server_name
|
server_name=server_name,
|
||||||
|
is_generating_file=False,
|
||||||
)
|
)
|
||||||
config.pop("log_config")
|
config.pop("log_config")
|
||||||
config.update(specified_config)
|
config.update(specified_config)
|
||||||
if "report_stats" not in config:
|
if "report_stats" not in config:
|
||||||
sys.stderr.write(
|
raise ConfigError(
|
||||||
"\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
|
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
|
||||||
MISSING_REPORT_STATS_SPIEL + "\n")
|
MISSING_REPORT_STATS_SPIEL
|
||||||
sys.exit(1)
|
)
|
||||||
|
|
||||||
if generate_keys:
|
if generate_keys:
|
||||||
obj.invoke_all("generate_files", config)
|
obj.invoke_all("generate_files", config)
|
||||||
sys.exit(0)
|
return
|
||||||
|
|
||||||
obj.invoke_all("read_config", config)
|
obj.invoke_all("read_config", config)
|
||||||
|
|
||||||
|
|
|
@ -22,8 +22,14 @@ from signedjson.key import (
|
||||||
read_signing_keys, write_signing_keys, NACL_ED25519
|
read_signing_keys, write_signing_keys, NACL_ED25519
|
||||||
)
|
)
|
||||||
from unpaddedbase64 import decode_base64
|
from unpaddedbase64 import decode_base64
|
||||||
|
from synapse.util.stringutils import random_string_with_symbols
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class KeyConfig(Config):
|
class KeyConfig(Config):
|
||||||
|
@ -40,9 +46,29 @@ class KeyConfig(Config):
|
||||||
config["perspectives"]
|
config["perspectives"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
self.macaroon_secret_key = config.get(
|
||||||
|
"macaroon_secret_key", self.registration_shared_secret
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.macaroon_secret_key:
|
||||||
|
# Unfortunately, there are people out there that don't have this
|
||||||
|
# set. Lets just be "nice" and derive one from their secret key.
|
||||||
|
logger.warn("Config is missing missing macaroon_secret_key")
|
||||||
|
seed = self.signing_key[0].seed
|
||||||
|
self.macaroon_secret_key = hashlib.sha256(seed)
|
||||||
|
|
||||||
|
def default_config(self, config_dir_path, server_name, is_generating_file=False,
|
||||||
|
**kwargs):
|
||||||
base_key_name = os.path.join(config_dir_path, server_name)
|
base_key_name = os.path.join(config_dir_path, server_name)
|
||||||
|
|
||||||
|
if is_generating_file:
|
||||||
|
macaroon_secret_key = random_string_with_symbols(50)
|
||||||
|
else:
|
||||||
|
macaroon_secret_key = None
|
||||||
|
|
||||||
return """\
|
return """\
|
||||||
|
macaroon_secret_key: "%(macaroon_secret_key)s"
|
||||||
|
|
||||||
## Signing Keys ##
|
## Signing Keys ##
|
||||||
|
|
||||||
# Path to the signing key to sign messages with
|
# Path to the signing key to sign messages with
|
||||||
|
|
|
@ -23,22 +23,23 @@ from distutils.util import strtobool
|
||||||
class RegistrationConfig(Config):
|
class RegistrationConfig(Config):
|
||||||
|
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
self.disable_registration = not bool(
|
self.enable_registration = bool(
|
||||||
strtobool(str(config["enable_registration"]))
|
strtobool(str(config["enable_registration"]))
|
||||||
)
|
)
|
||||||
if "disable_registration" in config:
|
if "disable_registration" in config:
|
||||||
self.disable_registration = bool(
|
self.enable_registration = not bool(
|
||||||
strtobool(str(config["disable_registration"]))
|
strtobool(str(config["disable_registration"]))
|
||||||
)
|
)
|
||||||
|
|
||||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||||
self.macaroon_secret_key = config.get("macaroon_secret_key")
|
|
||||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||||
|
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
||||||
self.allow_guest_access = config.get("allow_guest_access", False)
|
self.allow_guest_access = config.get("allow_guest_access", False)
|
||||||
|
|
||||||
def default_config(self, **kwargs):
|
def default_config(self, **kwargs):
|
||||||
registration_shared_secret = random_string_with_symbols(50)
|
registration_shared_secret = random_string_with_symbols(50)
|
||||||
macaroon_secret_key = random_string_with_symbols(50)
|
|
||||||
return """\
|
return """\
|
||||||
## Registration ##
|
## Registration ##
|
||||||
|
|
||||||
|
@ -49,8 +50,6 @@ class RegistrationConfig(Config):
|
||||||
# secret, even if registration is otherwise disabled.
|
# secret, even if registration is otherwise disabled.
|
||||||
registration_shared_secret: "%(registration_shared_secret)s"
|
registration_shared_secret: "%(registration_shared_secret)s"
|
||||||
|
|
||||||
macaroon_secret_key: "%(macaroon_secret_key)s"
|
|
||||||
|
|
||||||
# Set the number of bcrypt rounds used to generate password hash.
|
# Set the number of bcrypt rounds used to generate password hash.
|
||||||
# Larger numbers increase the work factor needed to generate the hash.
|
# Larger numbers increase the work factor needed to generate the hash.
|
||||||
# The default number of rounds is 12.
|
# The default number of rounds is 12.
|
||||||
|
@ -60,6 +59,12 @@ class RegistrationConfig(Config):
|
||||||
# participate in rooms hosted on this server which have been made
|
# participate in rooms hosted on this server which have been made
|
||||||
# accessible to anonymous users.
|
# accessible to anonymous users.
|
||||||
allow_guest_access: False
|
allow_guest_access: False
|
||||||
|
|
||||||
|
# The list of identity servers trusted to verify third party
|
||||||
|
# identifiers by this server.
|
||||||
|
trusted_third_party_id_servers:
|
||||||
|
- matrix.org
|
||||||
|
- vector.im
|
||||||
""" % locals()
|
""" % locals()
|
||||||
|
|
||||||
def add_arguments(self, parser):
|
def add_arguments(self, parser):
|
||||||
|
@ -71,6 +76,6 @@ class RegistrationConfig(Config):
|
||||||
|
|
||||||
def read_arguments(self, args):
|
def read_arguments(self, args):
|
||||||
if args.enable_registration is not None:
|
if args.enable_registration is not None:
|
||||||
self.disable_registration = not bool(
|
self.enable_registration = bool(
|
||||||
strtobool(str(args.enable_registration))
|
strtobool(str(args.enable_registration))
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
|
||||||
from synapse.util.retryutils import get_retry_limiter
|
from synapse.util.retryutils import get_retry_limiter
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
|
from synapse.util.logcontext import (
|
||||||
|
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
|
||||||
|
preserve_fn
|
||||||
|
)
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -142,6 +146,8 @@ class Keyring(object):
|
||||||
for server_name, _ in server_and_json
|
for server_name, _ in server_and_json
|
||||||
}
|
}
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
|
||||||
# We want to wait for any previous lookups to complete before
|
# We want to wait for any previous lookups to complete before
|
||||||
# proceeding.
|
# proceeding.
|
||||||
wait_on_deferred = self.wait_for_previous_lookups(
|
wait_on_deferred = self.wait_for_previous_lookups(
|
||||||
|
@ -175,7 +181,8 @@ class Keyring(object):
|
||||||
# Pass those keys to handle_key_deferred so that the json object
|
# Pass those keys to handle_key_deferred so that the json object
|
||||||
# signatures can be verified
|
# signatures can be verified
|
||||||
return [
|
return [
|
||||||
handle_key_deferred(
|
preserve_context_over_fn(
|
||||||
|
handle_key_deferred,
|
||||||
group_id_to_group[g_id],
|
group_id_to_group[g_id],
|
||||||
deferreds[g_id],
|
deferreds[g_id],
|
||||||
)
|
)
|
||||||
|
@ -198,12 +205,13 @@ class Keyring(object):
|
||||||
if server_name in self.key_downloads
|
if server_name in self.key_downloads
|
||||||
]
|
]
|
||||||
if wait_on:
|
if wait_on:
|
||||||
|
with PreserveLoggingContext():
|
||||||
yield defer.DeferredList(wait_on)
|
yield defer.DeferredList(wait_on)
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
for server_name, deferred in server_to_deferred.items():
|
for server_name, deferred in server_to_deferred.items():
|
||||||
d = ObservableDeferred(deferred)
|
d = ObservableDeferred(preserve_context_over_deferred(deferred))
|
||||||
self.key_downloads[server_name] = d
|
self.key_downloads[server_name] = d
|
||||||
|
|
||||||
def rm(r, server_name):
|
def rm(r, server_name):
|
||||||
|
@ -244,6 +252,7 @@ class Keyring(object):
|
||||||
for group in group_id_to_group.values():
|
for group in group_id_to_group.values():
|
||||||
for key_id in group.key_ids:
|
for key_id in group.key_ids:
|
||||||
if key_id in merged_results[group.server_name]:
|
if key_id in merged_results[group.server_name]:
|
||||||
|
with PreserveLoggingContext():
|
||||||
group_id_to_deferred[group.group_id].callback((
|
group_id_to_deferred[group.group_id].callback((
|
||||||
group.group_id,
|
group.group_id,
|
||||||
group.server_name,
|
group.server_name,
|
||||||
|
@ -504,7 +513,7 @@ class Keyring(object):
|
||||||
|
|
||||||
yield defer.gatherResults(
|
yield defer.gatherResults(
|
||||||
[
|
[
|
||||||
self.store_keys(
|
preserve_fn(self.store_keys)(
|
||||||
server_name=key_server_name,
|
server_name=key_server_name,
|
||||||
from_server=server_name,
|
from_server=server_name,
|
||||||
verify_keys=verify_keys,
|
verify_keys=verify_keys,
|
||||||
|
@ -573,7 +582,7 @@ class Keyring(object):
|
||||||
|
|
||||||
yield defer.gatherResults(
|
yield defer.gatherResults(
|
||||||
[
|
[
|
||||||
self.store.store_server_keys_json(
|
preserve_fn(self.store.store_server_keys_json)(
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
key_id=key_id,
|
key_id=key_id,
|
||||||
from_server=server_name,
|
from_server=server_name,
|
||||||
|
@ -675,7 +684,7 @@ class Keyring(object):
|
||||||
# TODO(markjh): Store whether the keys have expired.
|
# TODO(markjh): Store whether the keys have expired.
|
||||||
yield defer.gatherResults(
|
yield defer.gatherResults(
|
||||||
[
|
[
|
||||||
self.store.store_server_verify_key(
|
preserve_fn(self.store.store_server_verify_key)(
|
||||||
server_name, server_name, key.time_added, key
|
server_name, server_name, key.time_added, key
|
||||||
)
|
)
|
||||||
for key_id, key in verify_keys.items()
|
for key_id, key in verify_keys.items()
|
||||||
|
|
|
@ -20,3 +20,4 @@ class EventContext(object):
|
||||||
self.current_state = current_state
|
self.current_state = current_state
|
||||||
self.state_group = None
|
self.state_group = None
|
||||||
self.rejected = False
|
self.rejected = False
|
||||||
|
self.push_actions = []
|
||||||
|
|
|
@ -57,7 +57,7 @@ class FederationClient(FederationBase):
|
||||||
cache_name="get_pdu_cache",
|
cache_name="get_pdu_cache",
|
||||||
clock=self._clock,
|
clock=self._clock,
|
||||||
max_len=1000,
|
max_len=1000,
|
||||||
expiry_ms=120*1000,
|
expiry_ms=120 * 1000,
|
||||||
reset_expiry_on_get=False,
|
reset_expiry_on_get=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -126,10 +126,8 @@ class FederationServer(FederationBase):
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for pdu in pdu_list:
|
for pdu in pdu_list:
|
||||||
d = self._handle_new_pdu(transaction.origin, pdu)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield d
|
yield self._handle_new_pdu(transaction.origin, pdu)
|
||||||
results.append({})
|
results.append({})
|
||||||
except FederationError as e:
|
except FederationError as e:
|
||||||
self.send_failure(e, transaction.origin)
|
self.send_failure(e, transaction.origin)
|
||||||
|
|
|
@ -103,7 +103,6 @@ class TransactionQueue(object):
|
||||||
else:
|
else:
|
||||||
return not destination.startswith("localhost")
|
return not destination.startswith("localhost")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def enqueue_pdu(self, pdu, destinations, order):
|
def enqueue_pdu(self, pdu, destinations, order):
|
||||||
# We loop through all destinations to see whether we already have
|
# We loop through all destinations to see whether we already have
|
||||||
# a transaction in progress. If we do, stick it in the pending_pdus
|
# a transaction in progress. If we do, stick it in the pending_pdus
|
||||||
|
@ -141,8 +140,6 @@ class TransactionQueue(object):
|
||||||
|
|
||||||
deferreds.append(deferred)
|
deferreds.append(deferred)
|
||||||
|
|
||||||
yield defer.DeferredList(deferreds, consumeErrors=True)
|
|
||||||
|
|
||||||
# NO inlineCallbacks
|
# NO inlineCallbacks
|
||||||
def enqueue_edu(self, edu):
|
def enqueue_edu(self, edu):
|
||||||
destination = edu.destination
|
destination = edu.destination
|
||||||
|
|
|
@ -53,25 +53,10 @@ class BaseHandler(object):
|
||||||
self.event_builder_factory = hs.get_event_builder_factory()
|
self.event_builder_factory = hs.get_event_builder_factory()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _filter_events_for_clients(self, user_tuples, events):
|
def _filter_events_for_clients(self, user_tuples, events, event_id_to_state):
|
||||||
""" Returns dict of user_id -> list of events that user is allowed to
|
""" Returns dict of user_id -> list of events that user is allowed to
|
||||||
see.
|
see.
|
||||||
"""
|
"""
|
||||||
# If there is only one user, just get the state for that one user,
|
|
||||||
# otherwise just get all the state.
|
|
||||||
if len(user_tuples) == 1:
|
|
||||||
types = (
|
|
||||||
(EventTypes.RoomHistoryVisibility, ""),
|
|
||||||
(EventTypes.Member, user_tuples[0][0]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
types = None
|
|
||||||
|
|
||||||
event_id_to_state = yield self.store.get_state_for_events(
|
|
||||||
frozenset(e.event_id for e in events),
|
|
||||||
types=types
|
|
||||||
)
|
|
||||||
|
|
||||||
forgotten = yield defer.gatherResults([
|
forgotten = yield defer.gatherResults([
|
||||||
self.store.who_forgot_in_room(
|
self.store.who_forgot_in_room(
|
||||||
room_id,
|
room_id,
|
||||||
|
@ -135,7 +120,17 @@ class BaseHandler(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _filter_events_for_client(self, user_id, events, is_peeking=False):
|
def _filter_events_for_client(self, user_id, events, is_peeking=False):
|
||||||
# Assumes that user has at some point joined the room if not is_guest.
|
# Assumes that user has at some point joined the room if not is_guest.
|
||||||
res = yield self._filter_events_for_clients([(user_id, is_peeking)], events)
|
types = (
|
||||||
|
(EventTypes.RoomHistoryVisibility, ""),
|
||||||
|
(EventTypes.Member, user_id),
|
||||||
|
)
|
||||||
|
event_id_to_state = yield self.store.get_state_for_events(
|
||||||
|
frozenset(e.event_id for e in events),
|
||||||
|
types=types
|
||||||
|
)
|
||||||
|
res = yield self._filter_events_for_clients(
|
||||||
|
[(user_id, is_peeking)], events, event_id_to_state
|
||||||
|
)
|
||||||
defer.returnValue(res.get(user_id, []))
|
defer.returnValue(res.get(user_id, []))
|
||||||
|
|
||||||
def ratelimit(self, user_id):
|
def ratelimit(self, user_id):
|
||||||
|
@ -147,7 +142,7 @@ class BaseHandler(object):
|
||||||
)
|
)
|
||||||
if not allowed:
|
if not allowed:
|
||||||
raise LimitExceededError(
|
raise LimitExceededError(
|
||||||
retry_after_ms=int(1000*(time_allowed - time_now)),
|
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -269,13 +264,13 @@ class BaseHandler(object):
|
||||||
"You don't have permission to redact events"
|
"You don't have permission to redact events"
|
||||||
)
|
)
|
||||||
|
|
||||||
(event_stream_id, max_stream_id) = yield self.store.persist_event(
|
|
||||||
event, context=context
|
|
||||||
)
|
|
||||||
|
|
||||||
action_generator = ActionGenerator(self.hs)
|
action_generator = ActionGenerator(self.hs)
|
||||||
yield action_generator.handle_push_actions_for_event(
|
yield action_generator.handle_push_actions_for_event(
|
||||||
event, self
|
event, context, self
|
||||||
|
)
|
||||||
|
|
||||||
|
(event_stream_id, max_stream_id) = yield self.store.persist_event(
|
||||||
|
event, context=context
|
||||||
)
|
)
|
||||||
|
|
||||||
destinations = set()
|
destinations = set()
|
||||||
|
@ -293,19 +288,11 @@ class BaseHandler(object):
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
# Don't block waiting on waking up all the listeners.
|
# Don't block waiting on waking up all the listeners.
|
||||||
notify_d = self.notifier.on_new_room_event(
|
self.notifier.on_new_room_event(
|
||||||
event, event_stream_id, max_stream_id,
|
event, event_stream_id, max_stream_id,
|
||||||
extra_users=extra_users
|
extra_users=extra_users
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
|
||||||
logger.warn(
|
|
||||||
"Failed to notify about %s: %s",
|
|
||||||
event.event_id, f.value
|
|
||||||
)
|
|
||||||
|
|
||||||
notify_d.addErrback(log_failure)
|
|
||||||
|
|
||||||
# If invite, remove room_state from unsigned before sending.
|
# If invite, remove room_state from unsigned before sending.
|
||||||
event.unsigned.pop("invite_room_state", None)
|
event.unsigned.pop("invite_room_state", None)
|
||||||
|
|
||||||
|
|
|
@ -175,8 +175,8 @@ class DirectoryHandler(BaseHandler):
|
||||||
# If this server is in the list of servers, return it first.
|
# If this server is in the list of servers, return it first.
|
||||||
if self.server_name in servers:
|
if self.server_name in servers:
|
||||||
servers = (
|
servers = (
|
||||||
[self.server_name]
|
[self.server_name] +
|
||||||
+ [s for s in servers if s != self.server_name]
|
[s for s in servers if s != self.server_name]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
servers = list(servers)
|
servers = list(servers)
|
||||||
|
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
@ -29,11 +30,17 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def started_user_eventstream(distributor, user):
|
def started_user_eventstream(distributor, user):
|
||||||
return distributor.fire("started_user_eventstream", user)
|
return preserve_context_over_fn(
|
||||||
|
distributor.fire,
|
||||||
|
"started_user_eventstream", user
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def stopped_user_eventstream(distributor, user):
|
def stopped_user_eventstream(distributor, user):
|
||||||
return distributor.fire("stopped_user_eventstream", user)
|
return preserve_context_over_fn(
|
||||||
|
distributor.fire,
|
||||||
|
"stopped_user_eventstream", user
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EventStreamHandler(BaseHandler):
|
class EventStreamHandler(BaseHandler):
|
||||||
|
@ -130,7 +137,7 @@ class EventStreamHandler(BaseHandler):
|
||||||
|
|
||||||
# Add some randomness to this value to try and mitigate against
|
# Add some randomness to this value to try and mitigate against
|
||||||
# thundering herds on restart.
|
# thundering herds on restart.
|
||||||
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
|
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
|
||||||
|
|
||||||
events, tokens = yield self.notifier.get_events_for(
|
events, tokens = yield self.notifier.get_events_for(
|
||||||
auth_user, pagin_config, timeout,
|
auth_user, pagin_config, timeout,
|
||||||
|
|
|
@ -221,19 +221,11 @@ class FederationHandler(BaseHandler):
|
||||||
extra_users.append(target_user)
|
extra_users.append(target_user)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
d = self.notifier.on_new_room_event(
|
self.notifier.on_new_room_event(
|
||||||
event, event_stream_id, max_stream_id,
|
event, event_stream_id, max_stream_id,
|
||||||
extra_users=extra_users
|
extra_users=extra_users
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
|
||||||
logger.warn(
|
|
||||||
"Failed to notify about %s: %s",
|
|
||||||
event.event_id, f.value
|
|
||||||
)
|
|
||||||
|
|
||||||
d.addErrback(log_failure)
|
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
prev_state = context.current_state.get((event.type, event.state_key))
|
prev_state = context.current_state.get((event.type, event.state_key))
|
||||||
|
@ -244,12 +236,6 @@ class FederationHandler(BaseHandler):
|
||||||
user = UserID.from_string(event.state_key)
|
user = UserID.from_string(event.state_key)
|
||||||
yield user_joined_room(self.distributor, user, event.room_id)
|
yield user_joined_room(self.distributor, user, event.room_id)
|
||||||
|
|
||||||
if not backfilled and not event.internal_metadata.is_outlier():
|
|
||||||
action_generator = ActionGenerator(self.hs)
|
|
||||||
yield action_generator.handle_push_actions_for_event(
|
|
||||||
event, self
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _filter_events_for_server(self, server_name, room_id, events):
|
def _filter_events_for_server(self, server_name, room_id, events):
|
||||||
event_to_state = yield self.store.get_state_for_events(
|
event_to_state = yield self.store.get_state_for_events(
|
||||||
|
@ -643,19 +629,11 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
d = self.notifier.on_new_room_event(
|
self.notifier.on_new_room_event(
|
||||||
event, event_stream_id, max_stream_id,
|
event, event_stream_id, max_stream_id,
|
||||||
extra_users=[joinee]
|
extra_users=[joinee]
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
|
||||||
logger.warn(
|
|
||||||
"Failed to notify about %s: %s",
|
|
||||||
event.event_id, f.value
|
|
||||||
)
|
|
||||||
|
|
||||||
d.addErrback(log_failure)
|
|
||||||
|
|
||||||
logger.debug("Finished joining %s to %s", joinee, room_id)
|
logger.debug("Finished joining %s to %s", joinee, room_id)
|
||||||
finally:
|
finally:
|
||||||
room_queue = self.room_queues[room_id]
|
room_queue = self.room_queues[room_id]
|
||||||
|
@ -730,18 +708,10 @@ class FederationHandler(BaseHandler):
|
||||||
extra_users.append(target_user)
|
extra_users.append(target_user)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
d = self.notifier.on_new_room_event(
|
self.notifier.on_new_room_event(
|
||||||
event, event_stream_id, max_stream_id, extra_users=extra_users
|
event, event_stream_id, max_stream_id, extra_users=extra_users
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
|
||||||
logger.warn(
|
|
||||||
"Failed to notify about %s: %s",
|
|
||||||
event.event_id, f.value
|
|
||||||
)
|
|
||||||
|
|
||||||
d.addErrback(log_failure)
|
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
if event.content["membership"] == Membership.JOIN:
|
if event.content["membership"] == Membership.JOIN:
|
||||||
user = UserID.from_string(event.state_key)
|
user = UserID.from_string(event.state_key)
|
||||||
|
@ -811,19 +781,11 @@ class FederationHandler(BaseHandler):
|
||||||
|
|
||||||
target_user = UserID.from_string(event.state_key)
|
target_user = UserID.from_string(event.state_key)
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
d = self.notifier.on_new_room_event(
|
self.notifier.on_new_room_event(
|
||||||
event, event_stream_id, max_stream_id,
|
event, event_stream_id, max_stream_id,
|
||||||
extra_users=[target_user],
|
extra_users=[target_user],
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
|
||||||
logger.warn(
|
|
||||||
"Failed to notify about %s: %s",
|
|
||||||
event.event_id, f.value
|
|
||||||
)
|
|
||||||
|
|
||||||
d.addErrback(log_failure)
|
|
||||||
|
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -948,18 +910,10 @@ class FederationHandler(BaseHandler):
|
||||||
extra_users.append(target_user)
|
extra_users.append(target_user)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
d = self.notifier.on_new_room_event(
|
self.notifier.on_new_room_event(
|
||||||
event, event_stream_id, max_stream_id, extra_users=extra_users
|
event, event_stream_id, max_stream_id, extra_users=extra_users
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_failure(f):
|
|
||||||
logger.warn(
|
|
||||||
"Failed to notify about %s: %s",
|
|
||||||
event.event_id, f.value
|
|
||||||
)
|
|
||||||
|
|
||||||
d.addErrback(log_failure)
|
|
||||||
|
|
||||||
new_pdu = event
|
new_pdu = event
|
||||||
|
|
||||||
destinations = set()
|
destinations = set()
|
||||||
|
@ -1113,6 +1067,12 @@ class FederationHandler(BaseHandler):
|
||||||
auth_events=auth_events,
|
auth_events=auth_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not backfilled and not event.internal_metadata.is_outlier():
|
||||||
|
action_generator = ActionGenerator(self.hs)
|
||||||
|
yield action_generator.handle_push_actions_for_event(
|
||||||
|
event, context, self
|
||||||
|
)
|
||||||
|
|
||||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||||
event,
|
event,
|
||||||
context=context,
|
context=context,
|
||||||
|
|
|
@ -36,14 +36,15 @@ class IdentityHandler(BaseHandler):
|
||||||
|
|
||||||
self.http_client = hs.get_simple_http_client()
|
self.http_client = hs.get_simple_http_client()
|
||||||
|
|
||||||
|
self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
|
||||||
|
self.trust_any_id_server_just_for_testing_do_not_use = (
|
||||||
|
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def threepid_from_creds(self, creds):
|
def threepid_from_creds(self, creds):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
# XXX: make this configurable!
|
|
||||||
# trustedIdServers = ['matrix.org', 'localhost:8090']
|
|
||||||
trustedIdServers = ['matrix.org', 'vector.im']
|
|
||||||
|
|
||||||
if 'id_server' in creds:
|
if 'id_server' in creds:
|
||||||
id_server = creds['id_server']
|
id_server = creds['id_server']
|
||||||
elif 'idServer' in creds:
|
elif 'idServer' in creds:
|
||||||
|
@ -58,7 +59,16 @@ class IdentityHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
raise SynapseError(400, "No client_secret in creds")
|
raise SynapseError(400, "No client_secret in creds")
|
||||||
|
|
||||||
if id_server not in trustedIdServers:
|
if id_server not in self.trusted_id_servers:
|
||||||
|
if self.trust_any_id_server_just_for_testing_do_not_use:
|
||||||
|
logger.warn(
|
||||||
|
"Trusting untrustworthy ID server %r even though it isn't"
|
||||||
|
" in the trusted id list for testing because"
|
||||||
|
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
|
||||||
|
" is set in the config",
|
||||||
|
id_server,
|
||||||
|
)
|
||||||
|
else:
|
||||||
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
|
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
|
||||||
'credentials', id_server)
|
'credentials', id_server)
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
|
@ -34,7 +34,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Don't bother bumping "last active" time if it differs by less than 60 seconds
|
# Don't bother bumping "last active" time if it differs by less than 60 seconds
|
||||||
LAST_ACTIVE_GRANULARITY = 60*1000
|
LAST_ACTIVE_GRANULARITY = 60 * 1000
|
||||||
|
|
||||||
# Keep no more than this number of offline serial revisions
|
# Keep no more than this number of offline serial revisions
|
||||||
MAX_OFFLINE_SERIALS = 1000
|
MAX_OFFLINE_SERIALS = 1000
|
||||||
|
@ -378,9 +378,9 @@ class PresenceHandler(BaseHandler):
|
||||||
was_polling = target_user in self._user_cachemap
|
was_polling = target_user in self._user_cachemap
|
||||||
|
|
||||||
if now_online and not was_polling:
|
if now_online and not was_polling:
|
||||||
self.start_polling_presence(target_user, state=state)
|
yield self.start_polling_presence(target_user, state=state)
|
||||||
elif not now_online and was_polling:
|
elif not now_online and was_polling:
|
||||||
self.stop_polling_presence(target_user)
|
yield self.stop_polling_presence(target_user)
|
||||||
|
|
||||||
# TODO(paul): perform a presence push as part of start/stop poll so
|
# TODO(paul): perform a presence push as part of start/stop poll so
|
||||||
# we don't have to do this all the time
|
# we don't have to do this all the time
|
||||||
|
@ -394,6 +394,7 @@ class PresenceHandler(BaseHandler):
|
||||||
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
|
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
self.changed_presencelike_data(user, {"last_active": now})
|
self.changed_presencelike_data(user, {"last_active": now})
|
||||||
|
|
||||||
def get_joined_rooms_for_user(self, user):
|
def get_joined_rooms_for_user(self, user):
|
||||||
|
@ -466,6 +467,7 @@ class PresenceHandler(BaseHandler):
|
||||||
local_user, room_ids=[room_id], add_to_cache=False
|
local_user, room_ids=[room_id], add_to_cache=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
self.push_update_to_local_and_remote(
|
self.push_update_to_local_and_remote(
|
||||||
observed_user=local_user,
|
observed_user=local_user,
|
||||||
users_to_push=[user],
|
users_to_push=[user],
|
||||||
|
@ -556,7 +558,7 @@ class PresenceHandler(BaseHandler):
|
||||||
observer_user.localpart, observed_user.to_string()
|
observer_user.localpart, observed_user.to_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.start_polling_presence(
|
yield self.start_polling_presence(
|
||||||
observer_user, target_user=observed_user
|
observer_user, target_user=observed_user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,6 @@ from synapse.api.errors import (
|
||||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
||||||
)
|
)
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
import synapse.util.stringutils as stringutils
|
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.http.client import CaptchaServerHttpClient
|
from synapse.http.client import CaptchaServerHttpClient
|
||||||
|
|
||||||
|
@ -45,6 +44,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
self.distributor.declare("registered_user")
|
self.distributor.declare("registered_user")
|
||||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||||
|
|
||||||
|
self._next_generated_user_id = None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_username(self, localpart, guest_access_token=None):
|
def check_username(self, localpart, guest_access_token=None):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
@ -91,7 +92,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
localpart : The local part of the user ID to register. If None,
|
localpart : The local part of the user ID to register. If None,
|
||||||
one will be randomly generated.
|
one will be generated.
|
||||||
password (str) : The password to assign to this user so they can
|
password (str) : The password to assign to this user so they can
|
||||||
login again. This can be None which means they cannot login again
|
login again. This can be None which means they cannot login again
|
||||||
via a password (e.g. the user is an application service user).
|
via a password (e.g. the user is an application service user).
|
||||||
|
@ -108,6 +109,18 @@ class RegistrationHandler(BaseHandler):
|
||||||
if localpart:
|
if localpart:
|
||||||
yield self.check_username(localpart, guest_access_token=guest_access_token)
|
yield self.check_username(localpart, guest_access_token=guest_access_token)
|
||||||
|
|
||||||
|
was_guest = guest_access_token is not None
|
||||||
|
|
||||||
|
if not was_guest:
|
||||||
|
try:
|
||||||
|
int(localpart)
|
||||||
|
raise RegistrationError(
|
||||||
|
400,
|
||||||
|
"Numeric user IDs are reserved for guest users."
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
@ -118,38 +131,36 @@ class RegistrationHandler(BaseHandler):
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
was_guest=guest_access_token is not None,
|
was_guest=was_guest,
|
||||||
make_guest=make_guest,
|
make_guest=make_guest,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield registered_user(self.distributor, user)
|
yield registered_user(self.distributor, user)
|
||||||
else:
|
else:
|
||||||
# autogen a random user ID
|
# autogen a sequential user ID
|
||||||
attempts = 0
|
attempts = 0
|
||||||
user_id = None
|
|
||||||
token = None
|
token = None
|
||||||
while not user_id:
|
user = None
|
||||||
try:
|
while not user:
|
||||||
localpart = self._generate_user_id()
|
localpart = yield self._generate_user_id(attempts > 0)
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
yield self.check_user_id_is_valid(user_id)
|
yield self.check_user_id_is_valid(user_id)
|
||||||
if generate_token:
|
if generate_token:
|
||||||
token = self.auth_handler().generate_access_token(user_id)
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
|
try:
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
password_hash=password_hash)
|
password_hash=password_hash,
|
||||||
|
make_guest=make_guest
|
||||||
yield registered_user(self.distributor, user)
|
)
|
||||||
except SynapseError:
|
except SynapseError:
|
||||||
# if user id is taken, just generate another
|
# if user id is taken, just generate another
|
||||||
user_id = None
|
user_id = None
|
||||||
token = None
|
token = None
|
||||||
attempts += 1
|
attempts += 1
|
||||||
if attempts > 5:
|
yield registered_user(self.distributor, user)
|
||||||
raise RegistrationError(
|
|
||||||
500, "Cannot generate user ID.")
|
|
||||||
|
|
||||||
# We used to generate default identicons here, but nowadays
|
# We used to generate default identicons here, but nowadays
|
||||||
# we want clients to generate their own as part of their branding
|
# we want clients to generate their own as part of their branding
|
||||||
|
@ -175,7 +186,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
token=token,
|
token=token,
|
||||||
password_hash=""
|
password_hash=""
|
||||||
)
|
)
|
||||||
registered_user(self.distributor, user)
|
yield registered_user(self.distributor, user)
|
||||||
defer.returnValue((user_id, token))
|
defer.returnValue((user_id, token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -281,8 +292,16 @@ class RegistrationHandler(BaseHandler):
|
||||||
errcode=Codes.EXCLUSIVE
|
errcode=Codes.EXCLUSIVE
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_user_id(self):
|
@defer.inlineCallbacks
|
||||||
return "-" + stringutils.random_string(18)
|
def _generate_user_id(self, reseed=False):
|
||||||
|
if reseed or self._next_generated_user_id is None:
|
||||||
|
self._next_generated_user_id = (
|
||||||
|
yield self.store.find_next_generated_user_id_localpart()
|
||||||
|
)
|
||||||
|
|
||||||
|
id = self._next_generated_user_id
|
||||||
|
self._next_generated_user_id += 1
|
||||||
|
defer.returnValue(str(id))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _validate_captcha(self, ip_addr, private_key, challenge, response):
|
def _validate_captcha(self, ip_addr, private_key, challenge, response):
|
||||||
|
|
|
@ -18,13 +18,14 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
from synapse.types import UserID, RoomAlias, RoomID
|
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
EventTypes, Membership, JoinRules, RoomCreationPreset,
|
EventTypes, Membership, JoinRules, RoomCreationPreset,
|
||||||
)
|
)
|
||||||
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
|
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
|
||||||
from synapse.util import stringutils, unwrapFirstError
|
from synapse.util import stringutils, unwrapFirstError
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
|
|
||||||
from signedjson.sign import verify_signed_json
|
from signedjson.sign import verify_signed_json
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
@ -46,11 +47,17 @@ def collect_presencelike_data(distributor, user, content):
|
||||||
|
|
||||||
|
|
||||||
def user_left_room(distributor, user, room_id):
|
def user_left_room(distributor, user, room_id):
|
||||||
return distributor.fire("user_left_room", user=user, room_id=room_id)
|
return preserve_context_over_fn(
|
||||||
|
distributor.fire,
|
||||||
|
"user_left_room", user=user, room_id=room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def user_joined_room(distributor, user, room_id):
|
def user_joined_room(distributor, user, room_id):
|
||||||
return distributor.fire("user_joined_room", user=user, room_id=room_id)
|
return preserve_context_over_fn(
|
||||||
|
distributor.fire,
|
||||||
|
"user_joined_room", user=user, room_id=room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RoomCreationHandler(BaseHandler):
|
class RoomCreationHandler(BaseHandler):
|
||||||
|
@ -876,39 +883,71 @@ class RoomListHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_public_room_list(self):
|
def get_public_room_list(self):
|
||||||
chunk = yield self.store.get_rooms(is_public=True)
|
room_ids = yield self.store.get_public_room_ids()
|
||||||
|
|
||||||
room_members = yield defer.gatherResults(
|
|
||||||
[
|
|
||||||
self.store.get_users_in_room(room["room_id"])
|
|
||||||
for room in chunk
|
|
||||||
],
|
|
||||||
consumeErrors=True,
|
|
||||||
).addErrback(unwrapFirstError)
|
|
||||||
|
|
||||||
avatar_urls = yield defer.gatherResults(
|
|
||||||
[
|
|
||||||
self.get_room_avatar_url(room["room_id"])
|
|
||||||
for room in chunk
|
|
||||||
],
|
|
||||||
consumeErrors=True,
|
|
||||||
).addErrback(unwrapFirstError)
|
|
||||||
|
|
||||||
for i, room in enumerate(chunk):
|
|
||||||
room["num_joined_members"] = len(room_members[i])
|
|
||||||
if avatar_urls[i]:
|
|
||||||
room["avatar_url"] = avatar_urls[i]
|
|
||||||
|
|
||||||
# FIXME (erikj): START is no longer a valid value
|
|
||||||
defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_room_avatar_url(self, room_id):
|
def handle_room(room_id):
|
||||||
event = yield self.hs.get_state_handler().get_current_state(
|
aliases = yield self.store.get_aliases_for_room(room_id)
|
||||||
room_id, "m.room.avatar"
|
if not aliases:
|
||||||
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
state = yield self.state_handler.get_current_state(room_id)
|
||||||
|
|
||||||
|
result = {"aliases": aliases, "room_id": room_id}
|
||||||
|
|
||||||
|
name_event = state.get((EventTypes.Name, ""), None)
|
||||||
|
if name_event:
|
||||||
|
name = name_event.content.get("name", None)
|
||||||
|
if name:
|
||||||
|
result["name"] = name
|
||||||
|
|
||||||
|
topic_event = state.get((EventTypes.Topic, ""), None)
|
||||||
|
if topic_event:
|
||||||
|
topic = topic_event.content.get("topic", None)
|
||||||
|
if topic:
|
||||||
|
result["topic"] = topic
|
||||||
|
|
||||||
|
canonical_event = state.get((EventTypes.CanonicalAlias, ""), None)
|
||||||
|
if canonical_event:
|
||||||
|
canonical_alias = canonical_event.content.get("alias", None)
|
||||||
|
if canonical_alias:
|
||||||
|
result["canonical_alias"] = canonical_alias
|
||||||
|
|
||||||
|
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
|
||||||
|
visibility = None
|
||||||
|
if visibility_event:
|
||||||
|
visibility = visibility_event.content.get("history_visibility", None)
|
||||||
|
result["world_readable"] = visibility == "world_readable"
|
||||||
|
|
||||||
|
guest_event = state.get((EventTypes.GuestAccess, ""), None)
|
||||||
|
guest = None
|
||||||
|
if guest_event:
|
||||||
|
guest = guest_event.content.get("guest_access", None)
|
||||||
|
result["guest_can_join"] = guest == "can_join"
|
||||||
|
|
||||||
|
avatar_event = state.get(("m.room.avatar", ""), None)
|
||||||
|
if avatar_event:
|
||||||
|
avatar_url = avatar_event.content.get("url", None)
|
||||||
|
if avatar_url:
|
||||||
|
result["avatar_url"] = avatar_url
|
||||||
|
|
||||||
|
result["num_joined_members"] = sum(
|
||||||
|
1 for (event_type, _), ev in state.items()
|
||||||
|
if event_type == EventTypes.Member and ev.membership == Membership.JOIN
|
||||||
)
|
)
|
||||||
if event and "url" in event.content:
|
|
||||||
defer.returnValue(event.content["url"])
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
|
||||||
|
chunk_result = yield defer.gatherResults([
|
||||||
|
handle_room(room_id)
|
||||||
|
for room_id in chunk
|
||||||
|
], consumeErrors=True).addErrback(unwrapFirstError)
|
||||||
|
result.extend(v for v in chunk_result if v)
|
||||||
|
|
||||||
|
# FIXME (erikj): START is no longer a valid value
|
||||||
|
defer.returnValue({"start": "START", "end": "END", "chunk": result})
|
||||||
|
|
||||||
|
|
||||||
class RoomContextHandler(BaseHandler):
|
class RoomContextHandler(BaseHandler):
|
||||||
|
@ -927,7 +966,7 @@ class RoomContextHandler(BaseHandler):
|
||||||
Returns:
|
Returns:
|
||||||
dict, or None if the event isn't found
|
dict, or None if the event isn't found
|
||||||
"""
|
"""
|
||||||
before_limit = math.floor(limit/2.)
|
before_limit = math.floor(limit / 2.)
|
||||||
after_limit = limit - before_limit
|
after_limit = limit - before_limit
|
||||||
|
|
||||||
now_token = yield self.hs.get_event_sources().get_current_token()
|
now_token = yield self.hs.get_event_sources().get_current_token()
|
||||||
|
@ -997,6 +1036,11 @@ class RoomEventSource(object):
|
||||||
|
|
||||||
to_key = yield self.get_current_key()
|
to_key = yield self.get_current_key()
|
||||||
|
|
||||||
|
from_token = RoomStreamToken.parse(from_key)
|
||||||
|
if from_token.topological:
|
||||||
|
logger.warn("Stream has topological part!!!! %r", from_key)
|
||||||
|
from_key = "s%s" % (from_token.stream,)
|
||||||
|
|
||||||
app_service = yield self.store.get_app_service_by_user_id(
|
app_service = yield self.store.get_app_service_by_user_id(
|
||||||
user.to_string()
|
user.to_string()
|
||||||
)
|
)
|
||||||
|
@ -1008,15 +1052,30 @@ class RoomEventSource(object):
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
events, end_key = yield self.store.get_room_events_stream(
|
room_events = yield self.store.get_membership_changes_for_user(
|
||||||
user_id=user.to_string(),
|
user.to_string(), from_key, to_key
|
||||||
|
)
|
||||||
|
|
||||||
|
room_to_events = yield self.store.get_room_events_stream_for_rooms(
|
||||||
|
room_ids=room_ids,
|
||||||
from_key=from_key,
|
from_key=from_key,
|
||||||
to_key=to_key,
|
to_key=to_key,
|
||||||
limit=limit,
|
limit=limit or 10,
|
||||||
room_ids=room_ids,
|
|
||||||
is_guest=is_guest,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
events = list(room_events)
|
||||||
|
events.extend(e for evs, _ in room_to_events.values() for e in evs)
|
||||||
|
|
||||||
|
events.sort(key=lambda e: e.internal_metadata.order)
|
||||||
|
|
||||||
|
if limit:
|
||||||
|
events[:] = events[:limit]
|
||||||
|
|
||||||
|
if events:
|
||||||
|
end_key = events[-1].internal_metadata.after
|
||||||
|
else:
|
||||||
|
end_key = to_key
|
||||||
|
|
||||||
defer.returnValue((events, end_key))
|
defer.returnValue((events, end_key))
|
||||||
|
|
||||||
def get_current_key(self, direction='f'):
|
def get_current_key(self, direction='f'):
|
||||||
|
|
|
@ -18,11 +18,14 @@ from ._base import BaseHandler
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.api.constants import Membership, EventTypes
|
from synapse.api.constants import Membership, EventTypes
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
|
from synapse.util.logcontext import LoggingContext, preserve_fn
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
|
import itertools
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -139,6 +142,15 @@ class SyncHandler(BaseHandler):
|
||||||
A Deferred SyncResult.
|
A Deferred SyncResult.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
context = LoggingContext.current_context()
|
||||||
|
if context:
|
||||||
|
if since_token is None:
|
||||||
|
context.tag = "initial_sync"
|
||||||
|
elif full_state:
|
||||||
|
context.tag = "full_state_sync"
|
||||||
|
else:
|
||||||
|
context.tag = "incremental_sync"
|
||||||
|
|
||||||
if timeout == 0 or since_token is None or full_state:
|
if timeout == 0 or since_token is None or full_state:
|
||||||
# we are going to return immediately, so don't bother calling
|
# we are going to return immediately, so don't bother calling
|
||||||
# notifier.wait_for_events.
|
# notifier.wait_for_events.
|
||||||
|
@ -167,18 +179,6 @@ class SyncHandler(BaseHandler):
|
||||||
else:
|
else:
|
||||||
return self.incremental_sync_with_gap(sync_config, since_token)
|
return self.incremental_sync_with_gap(sync_config, since_token)
|
||||||
|
|
||||||
def last_read_event_id_for_room_and_user(self, room_id, user_id, ephemeral_by_room):
|
|
||||||
if room_id not in ephemeral_by_room:
|
|
||||||
return None
|
|
||||||
for e in ephemeral_by_room[room_id]:
|
|
||||||
if e['type'] != 'm.receipt':
|
|
||||||
continue
|
|
||||||
for receipt_event_id, val in e['content'].items():
|
|
||||||
if 'm.read' in val:
|
|
||||||
if user_id in val['m.read']:
|
|
||||||
return receipt_event_id
|
|
||||||
return None
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def full_state_sync(self, sync_config, timeline_since_token):
|
def full_state_sync(self, sync_config, timeline_since_token):
|
||||||
"""Get a sync for a client which is starting without any state.
|
"""Get a sync for a client which is starting without any state.
|
||||||
|
@ -228,9 +228,14 @@ class SyncHandler(BaseHandler):
|
||||||
invited = []
|
invited = []
|
||||||
archived = []
|
archived = []
|
||||||
deferreds = []
|
deferreds = []
|
||||||
for event in room_list:
|
|
||||||
|
room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)]
|
||||||
|
for room_list_chunk in room_list_chunks:
|
||||||
|
for event in room_list_chunk:
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
room_sync_deferred = self.full_state_sync_for_joined_room(
|
room_sync_deferred = preserve_fn(
|
||||||
|
self.full_state_sync_for_joined_room
|
||||||
|
)(
|
||||||
room_id=event.room_id,
|
room_id=event.room_id,
|
||||||
sync_config=sync_config,
|
sync_config=sync_config,
|
||||||
now_token=now_token,
|
now_token=now_token,
|
||||||
|
@ -251,7 +256,9 @@ class SyncHandler(BaseHandler):
|
||||||
leave_token = now_token.copy_and_replace(
|
leave_token = now_token.copy_and_replace(
|
||||||
"room_key", "s%d" % (event.stream_ordering,)
|
"room_key", "s%d" % (event.stream_ordering,)
|
||||||
)
|
)
|
||||||
room_sync_deferred = self.full_state_sync_for_archived_room(
|
room_sync_deferred = preserve_fn(
|
||||||
|
self.full_state_sync_for_archived_room
|
||||||
|
)(
|
||||||
sync_config=sync_config,
|
sync_config=sync_config,
|
||||||
room_id=event.room_id,
|
room_id=event.room_id,
|
||||||
leave_event_id=event.event_id,
|
leave_event_id=event.event_id,
|
||||||
|
@ -305,7 +312,6 @@ class SyncHandler(BaseHandler):
|
||||||
ephemeral_by_room=ephemeral_by_room,
|
ephemeral_by_room=ephemeral_by_room,
|
||||||
tags_by_room=tags_by_room,
|
tags_by_room=tags_by_room,
|
||||||
account_data_by_room=account_data_by_room,
|
account_data_by_room=account_data_by_room,
|
||||||
all_ephemeral_by_room=ephemeral_by_room,
|
|
||||||
batch=batch,
|
batch=batch,
|
||||||
full_state=True,
|
full_state=True,
|
||||||
)
|
)
|
||||||
|
@ -355,6 +361,7 @@ class SyncHandler(BaseHandler):
|
||||||
typing events for that room.
|
typing events for that room.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
with Measure(self.clock, "ephemeral_by_room"):
|
||||||
typing_key = since_token.typing_key if since_token else "0"
|
typing_key = since_token.typing_key if since_token else "0"
|
||||||
|
|
||||||
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
|
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
|
||||||
|
@ -438,13 +445,6 @@ class SyncHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
now_token = now_token.copy_and_replace("presence_key", presence_key)
|
now_token = now_token.copy_and_replace("presence_key", presence_key)
|
||||||
|
|
||||||
# We now fetch all ephemeral events for this room in order to get
|
|
||||||
# this users current read receipt. This could almost certainly be
|
|
||||||
# optimised.
|
|
||||||
_, all_ephemeral_by_room = yield self.ephemeral_by_room(
|
|
||||||
sync_config, now_token
|
|
||||||
)
|
|
||||||
|
|
||||||
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
|
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
|
||||||
sync_config, now_token, since_token
|
sync_config, now_token, since_token
|
||||||
)
|
)
|
||||||
|
@ -478,7 +478,7 @@ class SyncHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get a list of membership change events that have happened.
|
# Get a list of membership change events that have happened.
|
||||||
rooms_changed = yield self.store.get_room_changes_for_user(
|
rooms_changed = yield self.store.get_membership_changes_for_user(
|
||||||
user_id, since_token.room_key, now_token.room_key
|
user_id, since_token.room_key, now_token.room_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -576,7 +576,6 @@ class SyncHandler(BaseHandler):
|
||||||
ephemeral_by_room=ephemeral_by_room,
|
ephemeral_by_room=ephemeral_by_room,
|
||||||
tags_by_room=tags_by_room,
|
tags_by_room=tags_by_room,
|
||||||
account_data_by_room=account_data_by_room,
|
account_data_by_room=account_data_by_room,
|
||||||
all_ephemeral_by_room=all_ephemeral_by_room,
|
|
||||||
batch=batch,
|
batch=batch,
|
||||||
full_state=full_state,
|
full_state=full_state,
|
||||||
)
|
)
|
||||||
|
@ -606,6 +605,7 @@ class SyncHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
:returns a Deferred TimelineBatch
|
:returns a Deferred TimelineBatch
|
||||||
"""
|
"""
|
||||||
|
with Measure(self.clock, "load_filtered_recents"):
|
||||||
filtering_factor = 2
|
filtering_factor = 2
|
||||||
timeline_limit = sync_config.filter_collection.timeline_limit()
|
timeline_limit = sync_config.filter_collection.timeline_limit()
|
||||||
load_limit = max(timeline_limit * filtering_factor, 10)
|
load_limit = max(timeline_limit * filtering_factor, 10)
|
||||||
|
@ -613,7 +613,10 @@ class SyncHandler(BaseHandler):
|
||||||
room_key = now_token.room_key
|
room_key = now_token.room_key
|
||||||
end_key = room_key
|
end_key = room_key
|
||||||
|
|
||||||
limited = recents is None or newly_joined_room or timeline_limit < len(recents)
|
if recents is None or newly_joined_room or timeline_limit < len(recents):
|
||||||
|
limited = True
|
||||||
|
else:
|
||||||
|
limited = False
|
||||||
|
|
||||||
if recents is not None:
|
if recents is not None:
|
||||||
recents = sync_config.filter_collection.filter_room_timeline(recents)
|
recents = sync_config.filter_collection.filter_room_timeline(recents)
|
||||||
|
@ -636,7 +639,9 @@ class SyncHandler(BaseHandler):
|
||||||
from_key=since_key,
|
from_key=since_key,
|
||||||
to_key=end_key,
|
to_key=end_key,
|
||||||
)
|
)
|
||||||
loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
|
loaded_recents = sync_config.filter_collection.filter_room_timeline(
|
||||||
|
events
|
||||||
|
)
|
||||||
loaded_recents = yield self._filter_events_for_client(
|
loaded_recents = yield self._filter_events_for_client(
|
||||||
sync_config.user.to_string(),
|
sync_config.user.to_string(),
|
||||||
loaded_recents,
|
loaded_recents,
|
||||||
|
@ -670,37 +675,11 @@ class SyncHandler(BaseHandler):
|
||||||
since_token, now_token,
|
since_token, now_token,
|
||||||
ephemeral_by_room, tags_by_room,
|
ephemeral_by_room, tags_by_room,
|
||||||
account_data_by_room,
|
account_data_by_room,
|
||||||
all_ephemeral_by_room,
|
|
||||||
batch, full_state=False):
|
batch, full_state=False):
|
||||||
if full_state:
|
|
||||||
state = yield self.get_state_at(room_id, now_token)
|
|
||||||
|
|
||||||
elif batch.limited:
|
|
||||||
current_state = yield self.get_state_at(room_id, now_token)
|
|
||||||
|
|
||||||
state_at_previous_sync = yield self.get_state_at(
|
|
||||||
room_id, stream_position=since_token
|
|
||||||
)
|
|
||||||
|
|
||||||
state = yield self.compute_state_delta(
|
state = yield self.compute_state_delta(
|
||||||
since_token=since_token,
|
room_id, batch, sync_config, since_token, now_token,
|
||||||
previous_state=state_at_previous_sync,
|
full_state=full_state
|
||||||
current_state=current_state,
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
state = {
|
|
||||||
(event.type, event.state_key): event
|
|
||||||
for event in batch.events if event.is_state()
|
|
||||||
}
|
|
||||||
|
|
||||||
just_joined = yield self.check_joined_room(sync_config, state)
|
|
||||||
if just_joined:
|
|
||||||
state = yield self.get_state_at(room_id, now_token)
|
|
||||||
|
|
||||||
state = {
|
|
||||||
(e.type, e.state_key): e
|
|
||||||
for e in sync_config.filter_collection.filter_room_state(state.values())
|
|
||||||
}
|
|
||||||
|
|
||||||
account_data = self.account_data_for_room(
|
account_data = self.account_data_for_room(
|
||||||
room_id, tags_by_room, account_data_by_room
|
room_id, tags_by_room, account_data_by_room
|
||||||
|
@ -726,14 +705,12 @@ class SyncHandler(BaseHandler):
|
||||||
|
|
||||||
if room_sync:
|
if room_sync:
|
||||||
notifs = yield self.unread_notifs_for_room_id(
|
notifs = yield self.unread_notifs_for_room_id(
|
||||||
room_id, sync_config, all_ephemeral_by_room
|
room_id, sync_config
|
||||||
)
|
)
|
||||||
|
|
||||||
if notifs is not None:
|
if notifs is not None:
|
||||||
unread_notifications["notification_count"] = len(notifs)
|
unread_notifications["notification_count"] = notifs["notify_count"]
|
||||||
unread_notifications["highlight_count"] = len([
|
unread_notifications["highlight_count"] = notifs["highlight_count"]
|
||||||
1 for notif in notifs if _action_has_highlight(notif["actions"])
|
|
||||||
])
|
|
||||||
|
|
||||||
logger.debug("Room sync: %r", room_sync)
|
logger.debug("Room sync: %r", room_sync)
|
||||||
|
|
||||||
|
@ -766,29 +743,10 @@ class SyncHandler(BaseHandler):
|
||||||
|
|
||||||
logger.debug("Recents %r", batch)
|
logger.debug("Recents %r", batch)
|
||||||
|
|
||||||
state_events_at_leave = yield self.store.get_state_for_event(
|
|
||||||
leave_event_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if not full_state:
|
|
||||||
state_at_previous_sync = yield self.get_state_at(
|
|
||||||
room_id, stream_position=since_token
|
|
||||||
)
|
|
||||||
|
|
||||||
state_events_delta = yield self.compute_state_delta(
|
state_events_delta = yield self.compute_state_delta(
|
||||||
since_token=since_token,
|
room_id, batch, sync_config, since_token, leave_token,
|
||||||
previous_state=state_at_previous_sync,
|
full_state=full_state
|
||||||
current_state=state_events_at_leave,
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
state_events_delta = state_events_at_leave
|
|
||||||
|
|
||||||
state_events_delta = {
|
|
||||||
(e.type, e.state_key): e
|
|
||||||
for e in sync_config.filter_collection.filter_room_state(
|
|
||||||
state_events_delta.values()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
account_data = self.account_data_for_room(
|
account_data = self.account_data_for_room(
|
||||||
room_id, tags_by_room, account_data_by_room
|
room_id, tags_by_room, account_data_by_room
|
||||||
|
@ -843,15 +801,19 @@ class SyncHandler(BaseHandler):
|
||||||
state = {}
|
state = {}
|
||||||
defer.returnValue(state)
|
defer.returnValue(state)
|
||||||
|
|
||||||
def compute_state_delta(self, since_token, previous_state, current_state):
|
@defer.inlineCallbacks
|
||||||
""" Works out the differnce in state between the current state and the
|
def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
|
||||||
state the client got when it last performed a sync.
|
full_state):
|
||||||
|
""" Works out the differnce in state between the start of the timeline
|
||||||
|
and the previous sync.
|
||||||
|
|
||||||
:param str since_token: the point we are comparing against
|
:param str room_id
|
||||||
:param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
|
:param TimelineBatch batch: The timeline batch for the room that will
|
||||||
state to compare to
|
be sent to the user.
|
||||||
:param dict[(str,str), synapse.events.FrozenEvent] current_state: the
|
:param sync_config
|
||||||
new state
|
:param str since_token: Token of the end of the previous batch. May be None.
|
||||||
|
:param str now_token: Token of the end of the current batch.
|
||||||
|
:param bool full_state: Whether to force returning the full state.
|
||||||
|
|
||||||
:returns A new event dictionary
|
:returns A new event dictionary
|
||||||
"""
|
"""
|
||||||
|
@ -860,12 +822,53 @@ class SyncHandler(BaseHandler):
|
||||||
# updates even if they occured logically before the previous event.
|
# updates even if they occured logically before the previous event.
|
||||||
# TODO(mjark) Check for new redactions in the state events.
|
# TODO(mjark) Check for new redactions in the state events.
|
||||||
|
|
||||||
state_delta = {}
|
with Measure(self.clock, "compute_state_delta"):
|
||||||
for key, event in current_state.iteritems():
|
if full_state:
|
||||||
if (key not in previous_state or
|
if batch:
|
||||||
previous_state[key].event_id != event.event_id):
|
state = yield self.store.get_state_for_event(
|
||||||
state_delta[key] = event
|
batch.events[0].event_id
|
||||||
return state_delta
|
)
|
||||||
|
else:
|
||||||
|
state = yield self.get_state_at(
|
||||||
|
room_id, stream_position=now_token
|
||||||
|
)
|
||||||
|
|
||||||
|
timeline_state = {
|
||||||
|
(event.type, event.state_key): event
|
||||||
|
for event in batch.events if event.is_state()
|
||||||
|
}
|
||||||
|
|
||||||
|
state = _calculate_state(
|
||||||
|
timeline_contains=timeline_state,
|
||||||
|
timeline_start=state,
|
||||||
|
previous={},
|
||||||
|
)
|
||||||
|
elif batch.limited:
|
||||||
|
state_at_previous_sync = yield self.get_state_at(
|
||||||
|
room_id, stream_position=since_token
|
||||||
|
)
|
||||||
|
|
||||||
|
state_at_timeline_start = yield self.store.get_state_for_event(
|
||||||
|
batch.events[0].event_id
|
||||||
|
)
|
||||||
|
|
||||||
|
timeline_state = {
|
||||||
|
(event.type, event.state_key): event
|
||||||
|
for event in batch.events if event.is_state()
|
||||||
|
}
|
||||||
|
|
||||||
|
state = _calculate_state(
|
||||||
|
timeline_contains=timeline_state,
|
||||||
|
timeline_start=state_at_timeline_start,
|
||||||
|
previous=state_at_previous_sync,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
state = {}
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
(e.type, e.state_key): e
|
||||||
|
for e in sync_config.filter_collection.filter_room_state(state.values())
|
||||||
|
})
|
||||||
|
|
||||||
def check_joined_room(self, sync_config, state_delta):
|
def check_joined_room(self, sync_config, state_delta):
|
||||||
"""
|
"""
|
||||||
|
@ -886,9 +889,12 @@ class SyncHandler(BaseHandler):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def unread_notifs_for_room_id(self, room_id, sync_config, ephemeral_by_room):
|
def unread_notifs_for_room_id(self, room_id, sync_config):
|
||||||
last_unread_event_id = self.last_read_event_id_for_room_and_user(
|
with Measure(self.clock, "unread_notifs_for_room_id"):
|
||||||
room_id, sync_config.user.to_string(), ephemeral_by_room
|
last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
|
||||||
|
user_id=sync_config.user.to_string(),
|
||||||
|
room_id=room_id,
|
||||||
|
receipt_type="m.read"
|
||||||
)
|
)
|
||||||
|
|
||||||
notifs = []
|
notifs = []
|
||||||
|
@ -912,3 +918,37 @@ def _action_has_highlight(actions):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_state(timeline_contains, timeline_start, previous):
|
||||||
|
"""Works out what state to include in a sync response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeline_contains (dict): state in the timeline
|
||||||
|
timeline_start (dict): state at the start of the timeline
|
||||||
|
previous (dict): state at the end of the previous sync (or empty dict
|
||||||
|
if this is an initial sync)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict
|
||||||
|
"""
|
||||||
|
event_id_to_state = {
|
||||||
|
e.event_id: e
|
||||||
|
for e in itertools.chain(
|
||||||
|
timeline_contains.values(),
|
||||||
|
previous.values(),
|
||||||
|
timeline_start.values(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
tc_ids = set(e.event_id for e in timeline_contains.values())
|
||||||
|
p_ids = set(e.event_id for e in previous.values())
|
||||||
|
ts_ids = set(e.event_id for e in timeline_start.values())
|
||||||
|
|
||||||
|
state_ids = (ts_ids - p_ids) - tc_ids
|
||||||
|
|
||||||
|
evs = (event_id_to_state[e] for e in state_ids)
|
||||||
|
return {
|
||||||
|
(e.type, e.state_key): e
|
||||||
|
for e in evs
|
||||||
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ from ._base import BaseHandler
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, AuthError
|
from synapse.api.errors import SynapseError, AuthError
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -222,6 +223,7 @@ class TypingNotificationHandler(BaseHandler):
|
||||||
class TypingNotificationEventSource(object):
|
class TypingNotificationEventSource(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
self.clock = hs.get_clock()
|
||||||
self._handler = None
|
self._handler = None
|
||||||
self._room_member_handler = None
|
self._room_member_handler = None
|
||||||
|
|
||||||
|
@ -247,6 +249,7 @@ class TypingNotificationEventSource(object):
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_new_events(self, from_key, room_ids, **kwargs):
|
def get_new_events(self, from_key, room_ids, **kwargs):
|
||||||
|
with Measure(self.clock, "typing.get_new_events"):
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
handler = self.handler()
|
handler = self.handler()
|
||||||
|
|
||||||
|
|
|
@ -152,7 +152,7 @@ class MatrixFederationHttpClient(object):
|
||||||
|
|
||||||
return self.clock.time_bound_deferred(
|
return self.clock.time_bound_deferred(
|
||||||
request_deferred,
|
request_deferred,
|
||||||
time_out=timeout/1000. if timeout else 60,
|
time_out=timeout / 1000. if timeout else 60,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = yield preserve_context_over_fn(
|
response = yield preserve_context_over_fn(
|
||||||
|
|
|
@ -41,7 +41,7 @@ metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
incoming_requests_counter = metrics.register_counter(
|
incoming_requests_counter = metrics.register_counter(
|
||||||
"requests",
|
"requests",
|
||||||
labels=["method", "servlet"],
|
labels=["method", "servlet", "tag"],
|
||||||
)
|
)
|
||||||
outgoing_responses_counter = metrics.register_counter(
|
outgoing_responses_counter = metrics.register_counter(
|
||||||
"responses",
|
"responses",
|
||||||
|
@ -50,23 +50,23 @@ outgoing_responses_counter = metrics.register_counter(
|
||||||
|
|
||||||
response_timer = metrics.register_distribution(
|
response_timer = metrics.register_distribution(
|
||||||
"response_time",
|
"response_time",
|
||||||
labels=["method", "servlet"]
|
labels=["method", "servlet", "tag"]
|
||||||
)
|
)
|
||||||
|
|
||||||
response_ru_utime = metrics.register_distribution(
|
response_ru_utime = metrics.register_distribution(
|
||||||
"response_ru_utime", labels=["method", "servlet"]
|
"response_ru_utime", labels=["method", "servlet", "tag"]
|
||||||
)
|
)
|
||||||
|
|
||||||
response_ru_stime = metrics.register_distribution(
|
response_ru_stime = metrics.register_distribution(
|
||||||
"response_ru_stime", labels=["method", "servlet"]
|
"response_ru_stime", labels=["method", "servlet", "tag"]
|
||||||
)
|
)
|
||||||
|
|
||||||
response_db_txn_count = metrics.register_distribution(
|
response_db_txn_count = metrics.register_distribution(
|
||||||
"response_db_txn_count", labels=["method", "servlet"]
|
"response_db_txn_count", labels=["method", "servlet", "tag"]
|
||||||
)
|
)
|
||||||
|
|
||||||
response_db_txn_duration = metrics.register_distribution(
|
response_db_txn_duration = metrics.register_distribution(
|
||||||
"response_db_txn_duration", labels=["method", "servlet"]
|
"response_db_txn_duration", labels=["method", "servlet", "tag"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,9 +99,8 @@ def request_handler(request_handler):
|
||||||
request_context.request = request_id
|
request_context.request = request_id
|
||||||
with request.processing():
|
with request.processing():
|
||||||
try:
|
try:
|
||||||
d = request_handler(self, request)
|
with PreserveLoggingContext(request_context):
|
||||||
with PreserveLoggingContext():
|
yield request_handler(self, request)
|
||||||
yield d
|
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
code = e.code
|
code = e.code
|
||||||
if isinstance(e, SynapseError):
|
if isinstance(e, SynapseError):
|
||||||
|
@ -208,6 +207,9 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
if request.method == "OPTIONS":
|
if request.method == "OPTIONS":
|
||||||
self._send_response(request, 200, {})
|
self._send_response(request, 200, {})
|
||||||
return
|
return
|
||||||
|
|
||||||
|
start_context = LoggingContext.current_context()
|
||||||
|
|
||||||
# Loop through all the registered callbacks to check if the method
|
# Loop through all the registered callbacks to check if the method
|
||||||
# and path regex match
|
# and path regex match
|
||||||
for path_entry in self.path_regexs.get(request.method, []):
|
for path_entry in self.path_regexs.get(request.method, []):
|
||||||
|
@ -226,7 +228,6 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
servlet_classname = servlet_instance.__class__.__name__
|
servlet_classname = servlet_instance.__class__.__name__
|
||||||
else:
|
else:
|
||||||
servlet_classname = "%r" % callback
|
servlet_classname = "%r" % callback
|
||||||
incoming_requests_counter.inc(request.method, servlet_classname)
|
|
||||||
|
|
||||||
args = [
|
args = [
|
||||||
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
|
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
|
||||||
|
@ -237,21 +238,40 @@ class JsonResource(HttpServer, resource.Resource):
|
||||||
code, response = callback_return
|
code, response = callback_return
|
||||||
self._send_response(request, code, response)
|
self._send_response(request, code, response)
|
||||||
|
|
||||||
response_timer.inc_by(
|
|
||||||
self.clock.time_msec() - start, request.method, servlet_classname
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context = LoggingContext.current_context()
|
context = LoggingContext.current_context()
|
||||||
|
|
||||||
|
tag = ""
|
||||||
|
if context:
|
||||||
|
tag = context.tag
|
||||||
|
|
||||||
|
if context != start_context:
|
||||||
|
logger.warn(
|
||||||
|
"Context have unexpectedly changed %r, %r",
|
||||||
|
context, self.start_context
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
incoming_requests_counter.inc(request.method, servlet_classname, tag)
|
||||||
|
|
||||||
|
response_timer.inc_by(
|
||||||
|
self.clock.time_msec() - start, request.method,
|
||||||
|
servlet_classname, tag
|
||||||
|
)
|
||||||
|
|
||||||
ru_utime, ru_stime = context.get_resource_usage()
|
ru_utime, ru_stime = context.get_resource_usage()
|
||||||
|
|
||||||
response_ru_utime.inc_by(ru_utime, request.method, servlet_classname)
|
response_ru_utime.inc_by(
|
||||||
response_ru_stime.inc_by(ru_stime, request.method, servlet_classname)
|
ru_utime, request.method, servlet_classname, tag
|
||||||
|
)
|
||||||
|
response_ru_stime.inc_by(
|
||||||
|
ru_stime, request.method, servlet_classname, tag
|
||||||
|
)
|
||||||
response_db_txn_count.inc_by(
|
response_db_txn_count.inc_by(
|
||||||
context.db_txn_count, request.method, servlet_classname
|
context.db_txn_count, request.method, servlet_classname, tag
|
||||||
)
|
)
|
||||||
response_db_txn_duration.inc_by(
|
response_db_txn_duration.inc_by(
|
||||||
context.db_txn_duration, request.method, servlet_classname
|
context.db_txn_duration, request.method, servlet_classname, tag
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -18,10 +18,13 @@ from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
|
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import run_on_reactor, ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
from synapse.types import StreamToken
|
from synapse.types import StreamToken
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -71,6 +74,7 @@ class _NotifierUserStream(object):
|
||||||
self.current_token = current_token
|
self.current_token = current_token
|
||||||
self.last_notified_ms = time_now_ms
|
self.last_notified_ms = time_now_ms
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
||||||
|
|
||||||
def notify(self, stream_key, stream_id, time_now_ms):
|
def notify(self, stream_key, stream_id, time_now_ms):
|
||||||
|
@ -86,6 +90,8 @@ class _NotifierUserStream(object):
|
||||||
)
|
)
|
||||||
self.last_notified_ms = time_now_ms
|
self.last_notified_ms = time_now_ms
|
||||||
noify_deferred = self.notify_deferred
|
noify_deferred = self.notify_deferred
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
||||||
noify_deferred.callback(self.current_token)
|
noify_deferred.callback(self.current_token)
|
||||||
|
|
||||||
|
@ -118,6 +124,11 @@ class _NotifierUserStream(object):
|
||||||
return _NotificationListener(self.notify_deferred.observe())
|
return _NotificationListener(self.notify_deferred.observe())
|
||||||
|
|
||||||
|
|
||||||
|
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
|
||||||
|
def __nonzero__(self):
|
||||||
|
return bool(self.events)
|
||||||
|
|
||||||
|
|
||||||
class Notifier(object):
|
class Notifier(object):
|
||||||
""" This class is responsible for notifying any listeners when there are
|
""" This class is responsible for notifying any listeners when there are
|
||||||
new events available for it.
|
new events available for it.
|
||||||
|
@ -177,8 +188,6 @@ class Notifier(object):
|
||||||
lambda: count(bool, self.appservice_to_user_streams.values()),
|
lambda: count(bool, self.appservice_to_user_streams.values()),
|
||||||
)
|
)
|
||||||
|
|
||||||
@log_function
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
|
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
|
||||||
extra_users=[]):
|
extra_users=[]):
|
||||||
""" Used by handlers to inform the notifier something has happened
|
""" Used by handlers to inform the notifier something has happened
|
||||||
|
@ -192,8 +201,7 @@ class Notifier(object):
|
||||||
until all previous events have been persisted before notifying
|
until all previous events have been persisted before notifying
|
||||||
the client streams.
|
the client streams.
|
||||||
"""
|
"""
|
||||||
yield run_on_reactor()
|
with PreserveLoggingContext():
|
||||||
|
|
||||||
self.pending_new_room_events.append((
|
self.pending_new_room_events.append((
|
||||||
room_stream_id, event, extra_users
|
room_stream_id, event, extra_users
|
||||||
))
|
))
|
||||||
|
@ -244,15 +252,13 @@ class Notifier(object):
|
||||||
extra_streams=app_streams,
|
extra_streams=app_streams,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
@log_function
|
|
||||||
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
|
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
|
||||||
extra_streams=set()):
|
extra_streams=set()):
|
||||||
""" Used to inform listeners that something has happend event wise.
|
""" Used to inform listeners that something has happend event wise.
|
||||||
|
|
||||||
Will wake up all listeners for the given users and rooms.
|
Will wake up all listeners for the given users and rooms.
|
||||||
"""
|
"""
|
||||||
yield run_on_reactor()
|
with PreserveLoggingContext():
|
||||||
user_streams = set()
|
user_streams = set()
|
||||||
|
|
||||||
for user in users:
|
for user in users:
|
||||||
|
@ -301,7 +307,7 @@ class Notifier(object):
|
||||||
def timed_out():
|
def timed_out():
|
||||||
if listener:
|
if listener:
|
||||||
listener.deferred.cancel()
|
listener.deferred.cancel()
|
||||||
timer = self.clock.call_later(timeout/1000., timed_out)
|
timer = self.clock.call_later(timeout / 1000., timed_out)
|
||||||
|
|
||||||
prev_token = from_token
|
prev_token = from_token
|
||||||
while not result:
|
while not result:
|
||||||
|
@ -318,6 +324,7 @@ class Notifier(object):
|
||||||
# that we don't miss any current_token updates.
|
# that we don't miss any current_token updates.
|
||||||
prev_token = current_token
|
prev_token = current_token
|
||||||
listener = user_stream.new_listener(prev_token)
|
listener = user_stream.new_listener(prev_token)
|
||||||
|
with PreserveLoggingContext():
|
||||||
yield listener.deferred
|
yield listener.deferred
|
||||||
except defer.CancelledError:
|
except defer.CancelledError:
|
||||||
break
|
break
|
||||||
|
@ -356,7 +363,7 @@ class Notifier(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_for_updates(before_token, after_token):
|
def check_for_updates(before_token, after_token):
|
||||||
if not after_token.is_after(before_token):
|
if not after_token.is_after(before_token):
|
||||||
defer.returnValue(None)
|
defer.returnValue(EventStreamResult([], (from_token, from_token)))
|
||||||
|
|
||||||
events = []
|
events = []
|
||||||
end_token = from_token
|
end_token = from_token
|
||||||
|
@ -369,6 +376,7 @@ class Notifier(object):
|
||||||
continue
|
continue
|
||||||
if only_keys and name not in only_keys:
|
if only_keys and name not in only_keys:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_events, new_key = yield source.get_new_events(
|
new_events, new_key = yield source.get_new_events(
|
||||||
user=user,
|
user=user,
|
||||||
from_key=getattr(from_token, keyname),
|
from_key=getattr(from_token, keyname),
|
||||||
|
@ -388,10 +396,7 @@ class Notifier(object):
|
||||||
events.extend(new_events)
|
events.extend(new_events)
|
||||||
end_token = end_token.copy_and_replace(keyname, new_key)
|
end_token = end_token.copy_and_replace(keyname, new_key)
|
||||||
|
|
||||||
if events:
|
defer.returnValue(EventStreamResult(events, (from_token, end_token)))
|
||||||
defer.returnValue((events, (from_token, end_token)))
|
|
||||||
else:
|
|
||||||
defer.returnValue(None)
|
|
||||||
|
|
||||||
user_id_for_stream = user.to_string()
|
user_id_for_stream = user.to_string()
|
||||||
if is_peeking:
|
if is_peeking:
|
||||||
|
@ -415,9 +420,6 @@ class Notifier(object):
|
||||||
from_token=from_token,
|
from_token=from_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result is None:
|
|
||||||
result = ([], (from_token, from_token))
|
|
||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -17,6 +17,8 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.types import StreamToken
|
from synapse.types import StreamToken
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
import synapse.util.async
|
import synapse.util.async
|
||||||
import push_rule_evaluator as push_rule_evaluator
|
import push_rule_evaluator as push_rule_evaluator
|
||||||
|
@ -27,6 +29,16 @@ import random
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_NEXT_ID = 1
|
||||||
|
|
||||||
|
|
||||||
|
def _get_next_id():
|
||||||
|
global _NEXT_ID
|
||||||
|
_id = _NEXT_ID
|
||||||
|
_NEXT_ID += 1
|
||||||
|
return _id
|
||||||
|
|
||||||
|
|
||||||
# Pushers could now be moved to pull out of the event_push_actions table instead
|
# Pushers could now be moved to pull out of the event_push_actions table instead
|
||||||
# of listening on the event stream: this would avoid them having to run the
|
# of listening on the event stream: this would avoid them having to run the
|
||||||
# rules again.
|
# rules again.
|
||||||
|
@ -57,6 +69,8 @@ class Pusher(object):
|
||||||
self.alive = True
|
self.alive = True
|
||||||
self.badge = None
|
self.badge = None
|
||||||
|
|
||||||
|
self.name = "Pusher-%d" % (_get_next_id(),)
|
||||||
|
|
||||||
# The last value of last_active_time that we saw
|
# The last value of last_active_time that we saw
|
||||||
self.last_last_active_time = 0
|
self.last_last_active_time = 0
|
||||||
self.has_unread = True
|
self.has_unread = True
|
||||||
|
@ -86,6 +100,7 @@ class Pusher(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def start(self):
|
def start(self):
|
||||||
|
with LoggingContext(self.name):
|
||||||
if not self.last_token:
|
if not self.last_token:
|
||||||
# First-time setup: get a token to start from (we can't
|
# First-time setup: get a token to start from (we can't
|
||||||
# just start from no token, ie. 'now'
|
# just start from no token, ie. 'now'
|
||||||
|
@ -96,17 +111,24 @@ class Pusher(object):
|
||||||
self.user_id, config, timeout=0, affect_presence=False
|
self.user_id, config, timeout=0, affect_presence=False
|
||||||
)
|
)
|
||||||
self.last_token = chunk['end']
|
self.last_token = chunk['end']
|
||||||
self.store.update_pusher_last_token(
|
yield self.store.update_pusher_last_token(
|
||||||
self.app_id, self.pushkey, self.user_id, self.last_token
|
self.app_id, self.pushkey, self.user_id, self.last_token
|
||||||
)
|
)
|
||||||
logger.info("Pusher %s for user %s starting from token %s",
|
logger.info("New pusher %s for user %s starting from token %s",
|
||||||
self.pushkey, self.user_id, self.last_token)
|
self.pushkey, self.user_id, self.last_token)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Old pusher %s for user %s starting",
|
||||||
|
self.pushkey, self.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
wait = 0
|
wait = 0
|
||||||
while self.alive:
|
while self.alive:
|
||||||
try:
|
try:
|
||||||
if wait > 0:
|
if wait > 0:
|
||||||
yield synapse.util.async.sleep(wait)
|
yield synapse.util.async.sleep(wait)
|
||||||
|
with Measure(self.clock, "push"):
|
||||||
yield self.get_and_dispatch()
|
yield self.get_and_dispatch()
|
||||||
wait = 0
|
wait = 0
|
||||||
except:
|
except:
|
||||||
|
@ -316,7 +338,7 @@ class Pusher(object):
|
||||||
r.room_id, self.user_id, last_unread_event_id
|
r.room_id, self.user_id, last_unread_event_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
badge += len(notifs)
|
badge += notifs["notify_count"]
|
||||||
defer.returnValue(badge)
|
defer.returnValue(badge)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,8 +19,6 @@ import bulk_push_rule_evaluator
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,21 +34,15 @@ class ActionGenerator:
|
||||||
# tag (ie. we just need all the users).
|
# tag (ie. we just need all the users).
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_push_actions_for_event(self, event, handler):
|
def handle_push_actions_for_event(self, event, context, handler):
|
||||||
if event.type == EventTypes.Redaction and event.redacts is not None:
|
|
||||||
yield self.store.remove_push_actions_for_event_id(
|
|
||||||
event.room_id, event.redacts
|
|
||||||
)
|
|
||||||
|
|
||||||
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
|
bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
|
||||||
event.room_id, self.hs, self.store
|
event.room_id, self.hs, self.store
|
||||||
)
|
)
|
||||||
|
|
||||||
actions_by_user = yield bulk_evaluator.action_for_event_by_user(event, handler)
|
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
||||||
|
event, handler, context.current_state
|
||||||
|
)
|
||||||
|
|
||||||
yield self.store.set_push_actions_for_event_and_users(
|
context.push_actions = [
|
||||||
event,
|
|
||||||
[
|
|
||||||
(uid, None, actions) for uid, actions in actions_by_user.items()
|
(uid, None, actions) for uid, actions in actions_by_user.items()
|
||||||
]
|
]
|
||||||
)
|
|
||||||
|
|
|
@ -98,25 +98,21 @@ class BulkPushRuleEvaluator:
|
||||||
self.store = store
|
self.store = store
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def action_for_event_by_user(self, event, handler):
|
def action_for_event_by_user(self, event, handler, current_state):
|
||||||
actions_by_user = {}
|
actions_by_user = {}
|
||||||
|
|
||||||
users_dict = yield self.store.are_guests(self.rules_by_user.keys())
|
users_dict = yield self.store.are_guests(self.rules_by_user.keys())
|
||||||
|
|
||||||
filtered_by_user = yield handler._filter_events_for_clients(
|
filtered_by_user = yield handler._filter_events_for_clients(
|
||||||
users_dict.items(), [event]
|
users_dict.items(), [event], {event.event_id: current_state}
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
|
evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
|
||||||
|
|
||||||
condition_cache = {}
|
condition_cache = {}
|
||||||
|
|
||||||
member_state = yield self.store.get_state_for_event(
|
|
||||||
event.event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
display_names = {}
|
display_names = {}
|
||||||
for ev in member_state.values():
|
for ev in current_state.values():
|
||||||
nm = ev.content.get("displayname", None)
|
nm = ev.content.get("displayname", None)
|
||||||
if nm and ev.type == EventTypes.Member:
|
if nm and ev.type == EventTypes.Member:
|
||||||
display_names[ev.state_key] = nm
|
display_names[ev.state_key] = nm
|
||||||
|
|
|
@ -304,7 +304,7 @@ def _flatten_dict(d, prefix=[], result={}):
|
||||||
if isinstance(value, basestring):
|
if isinstance(value, basestring):
|
||||||
result[".".join(prefix + [key])] = value.lower()
|
result[".".join(prefix + [key])] = value.lower()
|
||||||
elif hasattr(value, "items"):
|
elif hasattr(value, "items"):
|
||||||
_flatten_dict(value, prefix=(prefix+[key]), result=result)
|
_flatten_dict(value, prefix=(prefix + [key]), result=result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from httppusher import HttpPusher
|
from httppusher import HttpPusher
|
||||||
from synapse.push import PusherConfigException
|
from synapse.push import PusherConfigException
|
||||||
|
from synapse.util.logcontext import preserve_fn
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -76,7 +77,7 @@ class PusherPool:
|
||||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||||
app_id, pushkey, p['user_name']
|
app_id, pushkey, p['user_name']
|
||||||
)
|
)
|
||||||
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def remove_pushers_by_user(self, user_id):
|
def remove_pushers_by_user(self, user_id):
|
||||||
|
@ -91,7 +92,7 @@ class PusherPool:
|
||||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||||
p['app_id'], p['pushkey'], p['user_name']
|
p['app_id'], p['pushkey'], p['user_name']
|
||||||
)
|
)
|
||||||
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
|
def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
|
||||||
|
@ -110,7 +111,7 @@ class PusherPool:
|
||||||
lang=lang,
|
lang=lang,
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
self._refresh_pusher(app_id, pushkey, user_id)
|
yield self._refresh_pusher(app_id, pushkey, user_id)
|
||||||
|
|
||||||
def _create_pusher(self, pusherdict):
|
def _create_pusher(self, pusherdict):
|
||||||
if pusherdict['kind'] == 'http':
|
if pusherdict['kind'] == 'http':
|
||||||
|
@ -166,7 +167,7 @@ class PusherPool:
|
||||||
if fullid in self.pushers:
|
if fullid in self.pushers:
|
||||||
self.pushers[fullid].stop()
|
self.pushers[fullid].stop()
|
||||||
self.pushers[fullid] = p
|
self.pushers[fullid] = p
|
||||||
p.start()
|
preserve_fn(p.start)()
|
||||||
|
|
||||||
logger.info("Started pushers")
|
logger.info("Started pushers")
|
||||||
|
|
||||||
|
|
|
@ -89,7 +89,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
LoginRestServlet.SAML2_TYPE):
|
LoginRestServlet.SAML2_TYPE):
|
||||||
relay_state = ""
|
relay_state = ""
|
||||||
if "relay_state" in login_submission:
|
if "relay_state" in login_submission:
|
||||||
relay_state = "&RelayState="+urllib.quote(
|
relay_state = "&RelayState=" + urllib.quote(
|
||||||
login_submission["relay_state"])
|
login_submission["relay_state"])
|
||||||
result = {
|
result = {
|
||||||
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
|
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
|
||||||
|
|
|
@ -33,7 +33,11 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {"displayname": displayname}))
|
ret = {}
|
||||||
|
if displayname is not None:
|
||||||
|
ret["displayname"] = displayname
|
||||||
|
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id):
|
def on_PUT(self, request, user_id):
|
||||||
|
@ -66,7 +70,11 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {"avatar_url": avatar_url}))
|
ret = {}
|
||||||
|
if avatar_url is not None:
|
||||||
|
ret["avatar_url"] = avatar_url
|
||||||
|
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id):
|
def on_PUT(self, request, user_id):
|
||||||
|
@ -102,10 +110,13 @@ class ProfileRestServlet(ClientV1RestServlet):
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {
|
ret = {}
|
||||||
"displayname": displayname,
|
if displayname is not None:
|
||||||
"avatar_url": avatar_url
|
ret["displayname"] = displayname
|
||||||
}))
|
if avatar_url is not None:
|
||||||
|
ret["avatar_url"] = avatar_url
|
||||||
|
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
|
|
|
@ -52,7 +52,7 @@ class PusherRestServlet(ClientV1RestServlet):
|
||||||
if i not in content:
|
if i not in content:
|
||||||
missing.append(i)
|
missing.append(i)
|
||||||
if len(missing):
|
if len(missing):
|
||||||
raise SynapseError(400, "Missing parameters: "+','.join(missing),
|
raise SynapseError(400, "Missing parameters: " + ','.join(missing),
|
||||||
errcode=Codes.MISSING_PARAM)
|
errcode=Codes.MISSING_PARAM)
|
||||||
|
|
||||||
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
|
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
|
||||||
|
@ -83,7 +83,7 @@ class PusherRestServlet(ClientV1RestServlet):
|
||||||
data=content['data']
|
data=content['data']
|
||||||
)
|
)
|
||||||
except PusherConfigException as pce:
|
except PusherConfigException as pce:
|
||||||
raise SynapseError(400, "Config Error: "+pce.message,
|
raise SynapseError(400, "Config Error: " + pce.message,
|
||||||
errcode=Codes.MISSING_PARAM)
|
errcode=Codes.MISSING_PARAM)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
|
@ -38,7 +38,8 @@ logger = logging.getLogger(__name__)
|
||||||
if hasattr(hmac, "compare_digest"):
|
if hasattr(hmac, "compare_digest"):
|
||||||
compare_digest = hmac.compare_digest
|
compare_digest = hmac.compare_digest
|
||||||
else:
|
else:
|
||||||
compare_digest = lambda a, b: a == b
|
def compare_digest(a, b):
|
||||||
|
return a == b
|
||||||
|
|
||||||
|
|
||||||
class RegisterRestServlet(ClientV1RestServlet):
|
class RegisterRestServlet(ClientV1RestServlet):
|
||||||
|
@ -58,7 +59,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
# }
|
# }
|
||||||
# TODO: persistent storage
|
# TODO: persistent storage
|
||||||
self.sessions = {}
|
self.sessions = {}
|
||||||
self.disable_registration = hs.config.disable_registration
|
self.enable_registration = hs.config.enable_registration
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
|
@ -112,7 +113,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
|
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
|
||||||
|
|
||||||
can_register = (
|
can_register = (
|
||||||
not self.disable_registration
|
self.enable_registration
|
||||||
or is_application_server
|
or is_application_server
|
||||||
or is_using_shared_secret
|
or is_using_shared_secret
|
||||||
)
|
)
|
||||||
|
|
|
@ -429,8 +429,6 @@ class RoomEventContext(ClientV1RestServlet):
|
||||||
serialize_event(event, time_now) for event in results["state"]
|
serialize_event(event, time_now) for event in results["state"]
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.info("Responding with %r", results)
|
|
||||||
|
|
||||||
defer.returnValue((200, results))
|
defer.returnValue((200, results))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -116,9 +116,10 @@ class ThreepidRestServlet(RestServlet):
|
||||||
|
|
||||||
body = parse_json_dict_from_request(request)
|
body = parse_json_dict_from_request(request)
|
||||||
|
|
||||||
if 'threePidCreds' not in body:
|
threePidCreds = body.get('threePidCreds')
|
||||||
|
threePidCreds = body.get('three_pid_creds', threePidCreds)
|
||||||
|
if threePidCreds is None:
|
||||||
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
|
||||||
threePidCreds = body['threePidCreds']
|
|
||||||
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
|
@ -57,7 +57,7 @@ class AccountDataServlet(RestServlet):
|
||||||
user_id, account_data_type, body
|
user_id, account_data_type, body
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.notifier.on_new_event(
|
self.notifier.on_new_event(
|
||||||
"account_data_key", max_id, users=[user_id]
|
"account_data_key", max_id, users=[user_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -99,7 +99,7 @@ class RoomAccountDataServlet(RestServlet):
|
||||||
user_id, room_id, account_data_type, body
|
user_id, room_id, account_data_type, body
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.notifier.on_new_event(
|
self.notifier.on_new_event(
|
||||||
"account_data_key", max_id, users=[user_id]
|
"account_data_key", max_id, users=[user_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,8 @@ from synapse.util.async import run_on_reactor
|
||||||
if hasattr(hmac, "compare_digest"):
|
if hasattr(hmac, "compare_digest"):
|
||||||
compare_digest = hmac.compare_digest
|
compare_digest = hmac.compare_digest
|
||||||
else:
|
else:
|
||||||
compare_digest = lambda a, b: a == b
|
def compare_digest(a, b):
|
||||||
|
return a == b
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -116,7 +117,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
return
|
return
|
||||||
|
|
||||||
# == Normal User Registration == (everyone else)
|
# == Normal User Registration == (everyone else)
|
||||||
if self.hs.config.disable_registration:
|
if not self.hs.config.enable_registration:
|
||||||
raise SynapseError(403, "Registration has been disabled")
|
raise SynapseError(403, "Registration has been disabled")
|
||||||
|
|
||||||
guest_access_token = body.get("guest_access_token", None)
|
guest_access_token = body.get("guest_access_token", None)
|
||||||
|
@ -152,6 +153,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
|
|
||||||
desired_username = params.get("username", None)
|
desired_username = params.get("username", None)
|
||||||
new_password = params.get("password", None)
|
new_password = params.get("password", None)
|
||||||
|
guest_access_token = params.get("guest_access_token", None)
|
||||||
|
|
||||||
(user_id, token) = yield self.registration_handler.register(
|
(user_id, token) = yield self.registration_handler.register(
|
||||||
localpart=desired_username,
|
localpart=desired_username,
|
||||||
|
|
|
@ -20,7 +20,6 @@ from synapse.http.servlet import (
|
||||||
)
|
)
|
||||||
from synapse.handlers.sync import SyncConfig
|
from synapse.handlers.sync import SyncConfig
|
||||||
from synapse.types import StreamToken
|
from synapse.types import StreamToken
|
||||||
from synapse.events import FrozenEvent
|
|
||||||
from synapse.events.utils import (
|
from synapse.events.utils import (
|
||||||
serialize_event, format_event_for_client_v2_without_room_id,
|
serialize_event, format_event_for_client_v2_without_room_id,
|
||||||
)
|
)
|
||||||
|
@ -287,9 +286,6 @@ class SyncRestServlet(RestServlet):
|
||||||
state_dict = room.state
|
state_dict = room.state
|
||||||
timeline_events = room.timeline.events
|
timeline_events = room.timeline.events
|
||||||
|
|
||||||
state_dict = SyncRestServlet._rollback_state_for_timeline(
|
|
||||||
state_dict, timeline_events)
|
|
||||||
|
|
||||||
state_events = state_dict.values()
|
state_events = state_dict.values()
|
||||||
|
|
||||||
serialized_state = [serialize(e) for e in state_events]
|
serialized_state = [serialize(e) for e in state_events]
|
||||||
|
@ -314,77 +310,6 @@ class SyncRestServlet(RestServlet):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _rollback_state_for_timeline(state, timeline):
|
|
||||||
"""
|
|
||||||
Wind the state dictionary backwards, so that it represents the
|
|
||||||
state at the start of the timeline, rather than at the end.
|
|
||||||
|
|
||||||
:param dict[(str, str), synapse.events.EventBase] state: the
|
|
||||||
state dictionary. Will be updated to the state before the timeline.
|
|
||||||
:param list[synapse.events.EventBase] timeline: the event timeline
|
|
||||||
:return: updated state dictionary
|
|
||||||
"""
|
|
||||||
|
|
||||||
result = state.copy()
|
|
||||||
|
|
||||||
for timeline_event in reversed(timeline):
|
|
||||||
if not timeline_event.is_state():
|
|
||||||
continue
|
|
||||||
|
|
||||||
event_key = (timeline_event.type, timeline_event.state_key)
|
|
||||||
|
|
||||||
logger.debug("Considering %s for removal", event_key)
|
|
||||||
|
|
||||||
state_event = result.get(event_key)
|
|
||||||
if (state_event is None or
|
|
||||||
state_event.event_id != timeline_event.event_id):
|
|
||||||
# the event in the timeline isn't present in the state
|
|
||||||
# dictionary.
|
|
||||||
#
|
|
||||||
# the most likely cause for this is that there was a fork in
|
|
||||||
# the event graph, and the state is no longer valid. Really,
|
|
||||||
# the event shouldn't be in the timeline. We're going to ignore
|
|
||||||
# it for now, however.
|
|
||||||
logger.debug("Found state event %r in timeline which doesn't "
|
|
||||||
"match state dictionary", timeline_event)
|
|
||||||
continue
|
|
||||||
|
|
||||||
prev_event_id = timeline_event.unsigned.get("replaces_state", None)
|
|
||||||
|
|
||||||
prev_content = timeline_event.unsigned.get('prev_content')
|
|
||||||
prev_sender = timeline_event.unsigned.get('prev_sender')
|
|
||||||
# Empircally it seems possible for the event to have a
|
|
||||||
# "replaces_state" key but not a prev_content or prev_sender
|
|
||||||
# markjh conjectures that it could be due to the server not
|
|
||||||
# having a copy of that event.
|
|
||||||
# If this is the case the we ignore the previous event. This will
|
|
||||||
# cause the displayname calculations on the client to be incorrect
|
|
||||||
if prev_event_id is None or not prev_content or not prev_sender:
|
|
||||||
logger.debug(
|
|
||||||
"Removing %r from the state dict, as it is missing"
|
|
||||||
" prev_content (prev_event_id=%r)",
|
|
||||||
timeline_event.event_id, prev_event_id
|
|
||||||
)
|
|
||||||
del result[event_key]
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"Replacing %r with %r in state dict",
|
|
||||||
timeline_event.event_id, prev_event_id
|
|
||||||
)
|
|
||||||
result[event_key] = FrozenEvent({
|
|
||||||
"type": timeline_event.type,
|
|
||||||
"state_key": timeline_event.state_key,
|
|
||||||
"content": prev_content,
|
|
||||||
"sender": prev_sender,
|
|
||||||
"event_id": prev_event_id,
|
|
||||||
"room_id": timeline_event.room_id,
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.debug("New value: %r", result.get(event_key))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
SyncRestServlet(hs).register(http_server)
|
SyncRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -80,7 +80,7 @@ class TagServlet(RestServlet):
|
||||||
|
|
||||||
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
|
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
|
||||||
|
|
||||||
yield self.notifier.on_new_event(
|
self.notifier.on_new_event(
|
||||||
"account_data_key", max_id, users=[user_id]
|
"account_data_key", max_id, users=[user_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ class TagServlet(RestServlet):
|
||||||
|
|
||||||
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
|
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
|
||||||
|
|
||||||
yield self.notifier.on_new_event(
|
self.notifier.on_new_event(
|
||||||
"account_data_key", max_id, users=[user_id]
|
"account_data_key", max_id, users=[user_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -26,9 +26,7 @@ class VersionsRestServlet(RestServlet):
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
return (200, {
|
return (200, {
|
||||||
"versions": [
|
"versions": ["r0.0.1"]
|
||||||
"r0.0.1",
|
|
||||||
]
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ from twisted.protocols.basic import FileSender
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.stringutils import is_ascii
|
from synapse.util.stringutils import is_ascii
|
||||||
|
from synapse.util.logcontext import preserve_context_over_fn
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
@ -276,7 +277,8 @@ class BaseMediaResource(Resource):
|
||||||
)
|
)
|
||||||
self._makedirs(t_path)
|
self._makedirs(t_path)
|
||||||
|
|
||||||
t_len = yield threads.deferToThread(
|
t_len = yield preserve_context_over_fn(
|
||||||
|
threads.deferToThread,
|
||||||
self._generate_thumbnail,
|
self._generate_thumbnail,
|
||||||
input_path, t_path, t_width, t_height, t_method, t_type
|
input_path, t_path, t_width, t_height, t_method, t_type
|
||||||
)
|
)
|
||||||
|
@ -298,7 +300,8 @@ class BaseMediaResource(Resource):
|
||||||
)
|
)
|
||||||
self._makedirs(t_path)
|
self._makedirs(t_path)
|
||||||
|
|
||||||
t_len = yield threads.deferToThread(
|
t_len = yield preserve_context_over_fn(
|
||||||
|
threads.deferToThread,
|
||||||
self._generate_thumbnail,
|
self._generate_thumbnail,
|
||||||
input_path, t_path, t_width, t_height, t_method, t_type
|
input_path, t_path, t_width, t_height, t_method, t_type
|
||||||
)
|
)
|
||||||
|
@ -372,7 +375,7 @@ class BaseMediaResource(Resource):
|
||||||
media_id, t_width, t_height, t_type, t_method, t_len
|
media_id, t_width, t_height, t_type, t_method, t_len
|
||||||
))
|
))
|
||||||
|
|
||||||
yield threads.deferToThread(generate_thumbnails)
|
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
|
||||||
|
|
||||||
for l in local_thumbnails:
|
for l in local_thumbnails:
|
||||||
yield self.store.store_local_thumbnail(*l)
|
yield self.store.store_local_thumbnail(*l)
|
||||||
|
@ -445,7 +448,7 @@ class BaseMediaResource(Resource):
|
||||||
t_width, t_height, t_type, t_method, t_len
|
t_width, t_height, t_type, t_method, t_len
|
||||||
])
|
])
|
||||||
|
|
||||||
yield threads.deferToThread(generate_thumbnails)
|
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
|
||||||
|
|
||||||
for r in remote_thumbnails:
|
for r in remote_thumbnails:
|
||||||
yield self.store.store_remote_media_thumbnail(*r)
|
yield self.store.store_remote_media_thumbnail(*r)
|
||||||
|
|
|
@ -63,7 +63,7 @@ class StateHandler(object):
|
||||||
cache_name="state_cache",
|
cache_name="state_cache",
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
max_len=SIZE_OF_CACHE,
|
max_len=SIZE_OF_CACHE,
|
||||||
expiry_ms=EVICTION_TIMEOUT_SECONDS*1000,
|
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
|
||||||
reset_expiry_on_get=True,
|
reset_expiry_on_get=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -45,9 +45,10 @@ from .search import SearchStore
|
||||||
from .tags import TagsStore
|
from .tags import TagsStore
|
||||||
from .account_data import AccountDataStore
|
from .account_data import AccountDataStore
|
||||||
|
|
||||||
|
|
||||||
from util.id_generators import IdGenerator, StreamIdGenerator
|
from util.id_generators import IdGenerator, StreamIdGenerator
|
||||||
|
|
||||||
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -58,7 +59,7 @@ logger = logging.getLogger(__name__)
|
||||||
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
||||||
# times give more inserts into the database even for readonly API hits
|
# times give more inserts into the database even for readonly API hits
|
||||||
# 120 seconds == 2 minutes
|
# 120 seconds == 2 minutes
|
||||||
LAST_SEEN_GRANULARITY = 120*1000
|
LAST_SEEN_GRANULARITY = 120 * 1000
|
||||||
|
|
||||||
|
|
||||||
class DataStore(RoomMemberStore, RoomStore,
|
class DataStore(RoomMemberStore, RoomStore,
|
||||||
|
@ -84,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
self.database_engine = hs.database_engine
|
||||||
|
|
||||||
cur = db_conn.cursor()
|
cur = db_conn.cursor()
|
||||||
try:
|
try:
|
||||||
|
@ -117,8 +119,61 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
||||||
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
||||||
|
|
||||||
|
events_max = self._stream_id_gen.get_max_token(None)
|
||||||
|
event_cache_prefill, min_event_val = self._get_cache_dict(
|
||||||
|
db_conn, "events",
|
||||||
|
entity_column="room_id",
|
||||||
|
stream_column="stream_ordering",
|
||||||
|
max_value=events_max,
|
||||||
|
)
|
||||||
|
self._events_stream_cache = StreamChangeCache(
|
||||||
|
"EventsRoomStreamChangeCache", min_event_val,
|
||||||
|
prefilled_cache=event_cache_prefill,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._membership_stream_cache = StreamChangeCache(
|
||||||
|
"MembershipStreamChangeCache", events_max,
|
||||||
|
)
|
||||||
|
|
||||||
|
account_max = self._account_data_id_gen.get_max_token(None)
|
||||||
|
self._account_data_stream_cache = StreamChangeCache(
|
||||||
|
"AccountDataAndTagsChangeCache", account_max,
|
||||||
|
)
|
||||||
|
|
||||||
super(DataStore, self).__init__(hs)
|
super(DataStore, self).__init__(hs)
|
||||||
|
|
||||||
|
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
|
||||||
|
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
|
||||||
|
# It doesn't really matter how many we get, the StreamChangeCache will
|
||||||
|
# do the right thing to ensure it respects the max size of cache.
|
||||||
|
sql = (
|
||||||
|
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
|
||||||
|
" WHERE %(stream)s > ? - 100000"
|
||||||
|
" GROUP BY %(entity)s"
|
||||||
|
) % {
|
||||||
|
"table": table,
|
||||||
|
"entity": entity_column,
|
||||||
|
"stream": stream_column,
|
||||||
|
}
|
||||||
|
|
||||||
|
sql = self.database_engine.convert_param_style(sql)
|
||||||
|
|
||||||
|
txn = db_conn.cursor()
|
||||||
|
txn.execute(sql, (int(max_value),))
|
||||||
|
rows = txn.fetchall()
|
||||||
|
|
||||||
|
cache = {
|
||||||
|
row[0]: int(row[1])
|
||||||
|
for row in rows
|
||||||
|
}
|
||||||
|
|
||||||
|
if cache:
|
||||||
|
min_val = min(cache.values())
|
||||||
|
else:
|
||||||
|
min_val = max_value
|
||||||
|
|
||||||
|
return cache, min_val
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def insert_client_ip(self, user, access_token, ip, user_agent):
|
def insert_client_ip(self, user, access_token, ip, user_agent):
|
||||||
now = int(self._clock.time_msec())
|
now = int(self._clock.time_msec())
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||||
from synapse.util.caches.descriptors import Cache
|
from synapse.util.caches.descriptors import Cache
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
@ -185,7 +185,7 @@ class SQLBaseStore(object):
|
||||||
time_then = self._previous_loop_ts
|
time_then = self._previous_loop_ts
|
||||||
self._previous_loop_ts = time_now
|
self._previous_loop_ts = time_now
|
||||||
|
|
||||||
ratio = (curr - prev)/(time_now - time_then)
|
ratio = (curr - prev) / (time_now - time_then)
|
||||||
|
|
||||||
top_three_counters = self._txn_perf_counters.interval(
|
top_three_counters = self._txn_perf_counters.interval(
|
||||||
time_now - time_then, limit=3
|
time_now - time_then, limit=3
|
||||||
|
@ -298,8 +298,8 @@ class SQLBaseStore(object):
|
||||||
func, *args, **kwargs
|
func, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
result = yield preserve_context_over_fn(
|
with PreserveLoggingContext():
|
||||||
self._db_pool.runWithConnection,
|
result = yield self._db_pool.runWithConnection(
|
||||||
inner_func, *args, **kwargs
|
inner_func, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -326,8 +326,8 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
return func(conn, *args, **kwargs)
|
return func(conn, *args, **kwargs)
|
||||||
|
|
||||||
result = yield preserve_context_over_fn(
|
with PreserveLoggingContext():
|
||||||
self._db_pool.runWithConnection,
|
result = yield self._db_pool.runWithConnection(
|
||||||
inner_func, *args, **kwargs
|
inner_func, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -643,7 +643,10 @@ class SQLBaseStore(object):
|
||||||
if not iterable:
|
if not iterable:
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)]
|
chunks = [
|
||||||
|
iterable[i:i + batch_size]
|
||||||
|
for i in xrange(0, len(iterable), batch_size)
|
||||||
|
]
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
rows = yield self.runInteraction(
|
rows = yield self.runInteraction(
|
||||||
desc,
|
desc,
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import ujson as json
|
import ujson as json
|
||||||
|
@ -24,14 +23,6 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AccountDataStore(SQLBaseStore):
|
class AccountDataStore(SQLBaseStore):
|
||||||
def __init__(self, hs):
|
|
||||||
super(AccountDataStore, self).__init__(hs)
|
|
||||||
|
|
||||||
self._account_data_stream_cache = StreamChangeCache(
|
|
||||||
"AccountDataAndTagsChangeCache",
|
|
||||||
self._account_data_id_gen.get_max_token(None),
|
|
||||||
max_size=10000,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_account_data_for_user(self, user_id):
|
def get_account_data_for_user(self, user_id):
|
||||||
"""Get all the client account_data for a user.
|
"""Get all the client account_data for a user.
|
||||||
|
@ -166,6 +157,10 @@ class AccountDataStore(SQLBaseStore):
|
||||||
"content": content_json,
|
"content": content_json,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
txn.call_after(
|
||||||
|
self._account_data_stream_cache.entity_has_changed,
|
||||||
|
user_id, next_id,
|
||||||
|
)
|
||||||
self._update_max_stream_id(txn, next_id)
|
self._update_max_stream_id(txn, next_id)
|
||||||
|
|
||||||
with (yield self._account_data_id_gen.get_next(self)) as next_id:
|
with (yield self._account_data_id_gen.get_next(self)) as next_id:
|
||||||
|
|
|
@ -276,7 +276,8 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
"application_services_state",
|
"application_services_state",
|
||||||
dict(as_id=service.id),
|
dict(as_id=service.id),
|
||||||
["state"],
|
["state"],
|
||||||
allow_none=True
|
allow_none=True,
|
||||||
|
desc="get_appservice_state",
|
||||||
)
|
)
|
||||||
if result:
|
if result:
|
||||||
defer.returnValue(result.get("state"))
|
defer.returnValue(result.get("state"))
|
||||||
|
|
|
@ -54,7 +54,7 @@ class Sqlite3Engine(object):
|
||||||
|
|
||||||
def _parse_match_info(buf):
|
def _parse_match_info(buf):
|
||||||
bufsize = len(buf)
|
bufsize = len(buf)
|
||||||
return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
|
return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)]
|
||||||
|
|
||||||
|
|
||||||
def _rank(raw_match_info):
|
def _rank(raw_match_info):
|
||||||
|
|
|
@ -58,7 +58,7 @@ class EventFederationStore(SQLBaseStore):
|
||||||
new_front = set()
|
new_front = set()
|
||||||
front_list = list(front)
|
front_list = list(front)
|
||||||
chunks = [
|
chunks = [
|
||||||
front_list[x:x+100]
|
front_list[x:x + 100]
|
||||||
for x in xrange(0, len(front), 100)
|
for x in xrange(0, len(front), 100)
|
||||||
]
|
]
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
|
|
|
@ -24,8 +24,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EventPushActionsStore(SQLBaseStore):
|
class EventPushActionsStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
|
||||||
def set_push_actions_for_event_and_users(self, event, tuples):
|
|
||||||
"""
|
"""
|
||||||
:param event: the event set actions for
|
:param event: the event set actions for
|
||||||
:param tuples: list of tuples of (user_id, profile_tag, actions)
|
:param tuples: list of tuples of (user_id, profile_tag, actions)
|
||||||
|
@ -37,21 +36,19 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
'event_id': event.event_id,
|
'event_id': event.event_id,
|
||||||
'user_id': uid,
|
'user_id': uid,
|
||||||
'profile_tag': profile_tag,
|
'profile_tag': profile_tag,
|
||||||
'actions': json.dumps(actions)
|
'actions': json.dumps(actions),
|
||||||
|
'stream_ordering': event.internal_metadata.stream_ordering,
|
||||||
|
'topological_ordering': event.depth,
|
||||||
|
'notif': 1,
|
||||||
|
'highlight': 1 if _action_has_highlight(actions) else 0,
|
||||||
})
|
})
|
||||||
|
|
||||||
def f(txn):
|
|
||||||
for uid, _, __ in tuples:
|
for uid, _, __ in tuples:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
||||||
(event.room_id, uid)
|
(event.room_id, uid)
|
||||||
)
|
)
|
||||||
return self._simple_insert_many_txn(txn, "event_push_actions", values)
|
self._simple_insert_many_txn(txn, "event_push_actions", values)
|
||||||
|
|
||||||
yield self.runInteraction(
|
|
||||||
"set_actions_for_event_and_users",
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3, lru=True, tree=True)
|
@cachedInlineCallbacks(num_args=3, lru=True, tree=True)
|
||||||
def get_unread_event_push_actions_by_room_for_user(
|
def get_unread_event_push_actions_by_room_for_user(
|
||||||
|
@ -68,32 +65,34 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
results = txn.fetchall()
|
results = txn.fetchall()
|
||||||
if len(results) == 0:
|
if len(results) == 0:
|
||||||
return []
|
return {"notify_count": 0, "highlight_count": 0}
|
||||||
|
|
||||||
stream_ordering = results[0][0]
|
stream_ordering = results[0][0]
|
||||||
topological_ordering = results[0][1]
|
topological_ordering = results[0][1]
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT ea.event_id, ea.actions"
|
"SELECT sum(notif), sum(highlight)"
|
||||||
" FROM event_push_actions ea, events e"
|
" FROM event_push_actions ea"
|
||||||
" WHERE ea.room_id = e.room_id"
|
" WHERE"
|
||||||
" AND ea.event_id = e.event_id"
|
" user_id = ?"
|
||||||
" AND ea.user_id = ?"
|
" AND room_id = ?"
|
||||||
" AND ea.room_id = ?"
|
|
||||||
" AND ("
|
" AND ("
|
||||||
" e.topological_ordering > ?"
|
" topological_ordering > ?"
|
||||||
" OR (e.topological_ordering = ? AND e.stream_ordering > ?)"
|
" OR (topological_ordering = ? AND stream_ordering > ?)"
|
||||||
")"
|
")"
|
||||||
)
|
)
|
||||||
txn.execute(sql, (
|
txn.execute(sql, (
|
||||||
user_id, room_id,
|
user_id, room_id,
|
||||||
topological_ordering, topological_ordering, stream_ordering
|
topological_ordering, topological_ordering, stream_ordering
|
||||||
)
|
))
|
||||||
)
|
row = txn.fetchone()
|
||||||
return [
|
if row:
|
||||||
{"event_id": row[0], "actions": json.loads(row[1])}
|
return {
|
||||||
for row in txn.fetchall()
|
"notify_count": row[0] or 0,
|
||||||
]
|
"highlight_count": row[1] or 0,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {"notify_count": 0, "highlight_count": 0}
|
||||||
|
|
||||||
ret = yield self.runInteraction(
|
ret = yield self.runInteraction(
|
||||||
"get_unread_event_push_actions_by_room",
|
"get_unread_event_push_actions_by_room",
|
||||||
|
@ -101,9 +100,7 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
|
||||||
def remove_push_actions_for_event_id(self, room_id, event_id):
|
|
||||||
def f(txn):
|
|
||||||
# Sad that we have to blow away the cache for the whole room here
|
# Sad that we have to blow away the cache for the whole room here
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
||||||
|
@ -113,7 +110,14 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
"DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
|
"DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
|
||||||
(room_id, event_id)
|
(room_id, event_id)
|
||||||
)
|
)
|
||||||
yield self.runInteraction(
|
|
||||||
"remove_push_actions_for_event_id",
|
|
||||||
f
|
def _action_has_highlight(actions):
|
||||||
)
|
for action in actions:
|
||||||
|
try:
|
||||||
|
if action.get("set_tweak", None) == "highlight":
|
||||||
|
return action.get("value", True)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
|
@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
|
||||||
from synapse.events import FrozenEvent, USE_FROZEN_DICTS
|
from synapse.events import FrozenEvent, USE_FROZEN_DICTS
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
|
|
||||||
from synapse.util.logcontext import preserve_context_over_deferred
|
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ class EventsStore(SQLBaseStore):
|
||||||
event.internal_metadata.stream_ordering = stream
|
event.internal_metadata.stream_ordering = stream
|
||||||
|
|
||||||
chunks = [
|
chunks = [
|
||||||
events_and_contexts[x:x+100]
|
events_and_contexts[x:x + 100]
|
||||||
for x in xrange(0, len(events_and_contexts), 100)
|
for x in xrange(0, len(events_and_contexts), 100)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -205,25 +205,31 @@ class EventsStore(SQLBaseStore):
|
||||||
@log_function
|
@log_function
|
||||||
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
|
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
|
||||||
is_new_state=True):
|
is_new_state=True):
|
||||||
|
depth_updates = {}
|
||||||
|
for event, context in events_and_contexts:
|
||||||
# Remove the any existing cache entries for the event_ids
|
# Remove the any existing cache entries for the event_ids
|
||||||
for event, _ in events_and_contexts:
|
|
||||||
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
||||||
|
|
||||||
if not backfilled:
|
if not backfilled:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self._events_stream_cache.entity_has_changed,
|
self._events_stream_cache.entity_has_changed,
|
||||||
event.room_id, event.internal_metadata.stream_ordering,
|
event.room_id, event.internal_metadata.stream_ordering,
|
||||||
)
|
)
|
||||||
|
|
||||||
depth_updates = {}
|
if not event.internal_metadata.is_outlier():
|
||||||
for event, _ in events_and_contexts:
|
|
||||||
if event.internal_metadata.is_outlier():
|
|
||||||
continue
|
|
||||||
depth_updates[event.room_id] = max(
|
depth_updates[event.room_id] = max(
|
||||||
event.depth, depth_updates.get(event.room_id, event.depth)
|
event.depth, depth_updates.get(event.room_id, event.depth)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if context.push_actions:
|
||||||
|
self._set_push_actions_for_event_and_users_txn(
|
||||||
|
txn, event, context.push_actions
|
||||||
|
)
|
||||||
|
|
||||||
|
if event.type == EventTypes.Redaction and event.redacts is not None:
|
||||||
|
self._remove_push_actions_for_event_id_txn(
|
||||||
|
txn, event.room_id, event.redacts
|
||||||
|
)
|
||||||
|
|
||||||
for room_id, depth in depth_updates.items():
|
for room_id, depth in depth_updates.items():
|
||||||
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
||||||
|
|
||||||
|
@ -664,6 +670,7 @@ class EventsStore(SQLBaseStore):
|
||||||
for ids, d in lst:
|
for ids, d in lst:
|
||||||
if not d.called:
|
if not d.called:
|
||||||
try:
|
try:
|
||||||
|
with PreserveLoggingContext():
|
||||||
d.callback([
|
d.callback([
|
||||||
res[i]
|
res[i]
|
||||||
for i in ids
|
for i in ids
|
||||||
|
@ -671,6 +678,7 @@ class EventsStore(SQLBaseStore):
|
||||||
])
|
])
|
||||||
except:
|
except:
|
||||||
logger.exception("Failed to callback")
|
logger.exception("Failed to callback")
|
||||||
|
with PreserveLoggingContext():
|
||||||
reactor.callFromThread(fire, event_list, row_dict)
|
reactor.callFromThread(fire, event_list, row_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("do_fetch")
|
logger.exception("do_fetch")
|
||||||
|
@ -679,9 +687,11 @@ class EventsStore(SQLBaseStore):
|
||||||
def fire(evs):
|
def fire(evs):
|
||||||
for _, d in evs:
|
for _, d in evs:
|
||||||
if not d.called:
|
if not d.called:
|
||||||
|
with PreserveLoggingContext():
|
||||||
d.errback(e)
|
d.errback(e)
|
||||||
|
|
||||||
if event_list:
|
if event_list:
|
||||||
|
with PreserveLoggingContext():
|
||||||
reactor.callFromThread(fire, event_list)
|
reactor.callFromThread(fire, event_list)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -709,18 +719,20 @@ class EventsStore(SQLBaseStore):
|
||||||
should_start = False
|
should_start = False
|
||||||
|
|
||||||
if should_start:
|
if should_start:
|
||||||
|
with PreserveLoggingContext():
|
||||||
self.runWithConnection(
|
self.runWithConnection(
|
||||||
self._do_fetch
|
self._do_fetch
|
||||||
)
|
)
|
||||||
|
|
||||||
rows = yield preserve_context_over_deferred(events_d)
|
with PreserveLoggingContext():
|
||||||
|
rows = yield events_d
|
||||||
|
|
||||||
if not allow_rejected:
|
if not allow_rejected:
|
||||||
rows[:] = [r for r in rows if not r["rejects"]]
|
rows[:] = [r for r in rows if not r["rejects"]]
|
||||||
|
|
||||||
res = yield defer.gatherResults(
|
res = yield defer.gatherResults(
|
||||||
[
|
[
|
||||||
self._get_event_from_row(
|
preserve_fn(self._get_event_from_row)(
|
||||||
row["internal_metadata"], row["json"], row["redacts"],
|
row["internal_metadata"], row["json"], row["redacts"],
|
||||||
check_redacted=check_redacted,
|
check_redacted=check_redacted,
|
||||||
get_prev_content=get_prev_content,
|
get_prev_content=get_prev_content,
|
||||||
|
@ -740,7 +752,7 @@ class EventsStore(SQLBaseStore):
|
||||||
rows = []
|
rows = []
|
||||||
N = 200
|
N = 200
|
||||||
for i in range(1 + len(events) / N):
|
for i in range(1 + len(events) / N):
|
||||||
evs = events[i*N:(i + 1)*N]
|
evs = events[i * N:(i + 1) * N]
|
||||||
if not evs:
|
if not evs:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -755,7 +767,7 @@ class EventsStore(SQLBaseStore):
|
||||||
" LEFT JOIN rejections as rej USING (event_id)"
|
" LEFT JOIN rejections as rej USING (event_id)"
|
||||||
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
|
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
|
||||||
" WHERE e.event_id IN (%s)"
|
" WHERE e.event_id IN (%s)"
|
||||||
) % (",".join(["?"]*len(evs)),)
|
) % (",".join(["?"] * len(evs)),)
|
||||||
|
|
||||||
txn.execute(sql, evs)
|
txn.execute(sql, evs)
|
||||||
rows.extend(self.cursor_to_dict(txn))
|
rows.extend(self.cursor_to_dict(txn))
|
||||||
|
|
|
@ -39,6 +39,7 @@ class KeyStore(SQLBaseStore):
|
||||||
table="server_tls_certificates",
|
table="server_tls_certificates",
|
||||||
keyvalues={"server_name": server_name},
|
keyvalues={"server_name": server_name},
|
||||||
retcols=("tls_certificate",),
|
retcols=("tls_certificate",),
|
||||||
|
desc="get_server_certificate",
|
||||||
)
|
)
|
||||||
tls_certificate = OpenSSL.crypto.load_certificate(
|
tls_certificate = OpenSSL.crypto.load_certificate(
|
||||||
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,
|
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,
|
||||||
|
|
|
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
SCHEMA_VERSION = 28
|
SCHEMA_VERSION = 29
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
@ -211,7 +211,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||||
logger.debug("applied_delta_files: %s", applied_delta_files)
|
logger.debug("applied_delta_files: %s", applied_delta_files)
|
||||||
|
|
||||||
for v in range(start_ver, SCHEMA_VERSION + 1):
|
for v in range(start_ver, SCHEMA_VERSION + 1):
|
||||||
logger.debug("Upgrading schema to v%d", v)
|
logger.info("Upgrading schema to v%d", v)
|
||||||
|
|
||||||
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
|
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
|
||||||
|
|
||||||
|
|
|
@ -68,8 +68,9 @@ class PresenceStore(SQLBaseStore):
|
||||||
for row in rows
|
for row in rows
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def set_presence_state(self, user_localpart, new_state):
|
def set_presence_state(self, user_localpart, new_state):
|
||||||
res = self._simple_update_one(
|
res = yield self._simple_update_one(
|
||||||
table="presence",
|
table="presence",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
updatevalues={"state": new_state["state"],
|
updatevalues={"state": new_state["state"],
|
||||||
|
@ -79,7 +80,7 @@ class PresenceStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_presence_state.invalidate((user_localpart,))
|
self.get_presence_state.invalidate((user_localpart,))
|
||||||
return res
|
defer.returnValue(res)
|
||||||
|
|
||||||
def allow_presence_visible(self, observed_localpart, observer_userid):
|
def allow_presence_visible(self, observed_localpart, observer_userid):
|
||||||
return self._simple_insert(
|
return self._simple_insert(
|
||||||
|
|
|
@ -46,6 +46,20 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
desc="get_receipts_for_room",
|
desc="get_receipts_for_room",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached(num_args=3)
|
||||||
|
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
|
||||||
|
return self._simple_select_one_onecol(
|
||||||
|
table="receipts_linearized",
|
||||||
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
"receipt_type": receipt_type,
|
||||||
|
"user_id": user_id
|
||||||
|
},
|
||||||
|
retcol="event_id",
|
||||||
|
desc="get_own_receipt_for_user",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2)
|
@cachedInlineCallbacks(num_args=2)
|
||||||
def get_receipts_for_user(self, user_id, receipt_type):
|
def get_receipts_for_user(self, user_id, receipt_type):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
@ -226,6 +240,11 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
room_id, stream_id
|
room_id, stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
txn.call_after(
|
||||||
|
self.get_last_receipt_event_id_for_user.invalidate,
|
||||||
|
(user_id, room_id, receipt_type)
|
||||||
|
)
|
||||||
|
|
||||||
# We don't want to clobber receipts for more recent events, so we
|
# We don't want to clobber receipts for more recent events, so we
|
||||||
# have to compare orderings of existing receipts
|
# have to compare orderings of existing receipts
|
||||||
sql = (
|
sql = (
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError, Codes
|
from synapse.api.errors import StoreError, Codes
|
||||||
|
@ -134,6 +136,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
},
|
},
|
||||||
retcols=["name", "password_hash", "is_guest"],
|
retcols=["name", "password_hash", "is_guest"],
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
|
desc="get_user_by_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_users_by_id_case_insensitive(self, user_id):
|
def get_users_by_id_case_insensitive(self, user_id):
|
||||||
|
@ -350,3 +353,37 @@ class RegistrationStore(SQLBaseStore):
|
||||||
|
|
||||||
ret = yield self.runInteraction("count_users", _count_users)
|
ret = yield self.runInteraction("count_users", _count_users)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def find_next_generated_user_id_localpart(self):
|
||||||
|
"""
|
||||||
|
Gets the localpart of the next generated user ID.
|
||||||
|
|
||||||
|
Generated user IDs are integers, and we aim for them to be as small as
|
||||||
|
we can. Unfortunately, it's possible some of them are already taken by
|
||||||
|
existing users, and there may be gaps in the already taken range. This
|
||||||
|
function returns the start of the first allocatable gap. This is to
|
||||||
|
avoid the case of ID 10000000 being pre-allocated, so us wasting the
|
||||||
|
first (and shortest) many generated user IDs.
|
||||||
|
"""
|
||||||
|
def _find_next_generated_user_id(txn):
|
||||||
|
txn.execute("SELECT name FROM users")
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
regex = re.compile("^@(\d+):")
|
||||||
|
|
||||||
|
found = set()
|
||||||
|
|
||||||
|
for r in rows:
|
||||||
|
user_id = r["name"]
|
||||||
|
match = regex.search(user_id)
|
||||||
|
if match:
|
||||||
|
found.add(int(match.group(1)))
|
||||||
|
for i in xrange(len(found) + 1):
|
||||||
|
if i not in found:
|
||||||
|
return i
|
||||||
|
|
||||||
|
defer.returnValue((yield self.runInteraction(
|
||||||
|
"find_next_generated_user_id",
|
||||||
|
_find_next_generated_user_id
|
||||||
|
)))
|
||||||
|
|
|
@ -87,90 +87,20 @@ class RoomStore(SQLBaseStore):
|
||||||
desc="get_public_room_ids",
|
desc="get_public_room_ids",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def get_room_count(self):
|
||||||
def get_rooms(self, is_public):
|
"""Retrieve a list of all rooms
|
||||||
"""Retrieve a list of all public rooms.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
is_public (bool): True if the rooms returned should be public.
|
|
||||||
Returns:
|
|
||||||
A list of room dicts containing at least a "room_id" key, a
|
|
||||||
"topic" key if one is set, and a "name" key if one is set
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
def subquery(table_name, column_name=None):
|
sql = "SELECT count(*) FROM rooms"
|
||||||
column_name = column_name or table_name
|
txn.execute(sql)
|
||||||
return (
|
row = txn.fetchone()
|
||||||
"SELECT %(table_name)s.event_id as event_id, "
|
return row[0] or 0
|
||||||
"%(table_name)s.room_id as room_id, %(column_name)s "
|
|
||||||
"FROM %(table_name)s "
|
|
||||||
"INNER JOIN current_state_events as c "
|
|
||||||
"ON c.event_id = %(table_name)s.event_id " % {
|
|
||||||
"column_name": column_name,
|
|
||||||
"table_name": table_name,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
sql = (
|
return self.runInteraction(
|
||||||
"SELECT"
|
|
||||||
" r.room_id,"
|
|
||||||
" max(n.name),"
|
|
||||||
" max(t.topic),"
|
|
||||||
" max(v.history_visibility),"
|
|
||||||
" max(g.guest_access)"
|
|
||||||
" FROM rooms AS r"
|
|
||||||
" LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
|
|
||||||
" LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
|
|
||||||
" LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id"
|
|
||||||
" LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id"
|
|
||||||
" WHERE r.is_public = ?"
|
|
||||||
" GROUP BY r.room_id" % {
|
|
||||||
"topic": subquery("topics", "topic"),
|
|
||||||
"name": subquery("room_names", "name"),
|
|
||||||
"history_visibility": subquery("history_visibility"),
|
|
||||||
"guest_access": subquery("guest_access"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(sql, (is_public,))
|
|
||||||
|
|
||||||
rows = txn.fetchall()
|
|
||||||
|
|
||||||
for i, row in enumerate(rows):
|
|
||||||
room_id = row[0]
|
|
||||||
aliases = self._simple_select_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="room_aliases",
|
|
||||||
keyvalues={
|
|
||||||
"room_id": room_id
|
|
||||||
},
|
|
||||||
retcol="room_alias",
|
|
||||||
)
|
|
||||||
|
|
||||||
rows[i] = list(row) + [aliases]
|
|
||||||
|
|
||||||
return rows
|
|
||||||
|
|
||||||
rows = yield self.runInteraction(
|
|
||||||
"get_rooms", f
|
"get_rooms", f
|
||||||
)
|
)
|
||||||
|
|
||||||
ret = [
|
|
||||||
{
|
|
||||||
"room_id": r[0],
|
|
||||||
"name": r[1],
|
|
||||||
"topic": r[2],
|
|
||||||
"world_readable": r[3] == "world_readable",
|
|
||||||
"guest_can_join": r[4] == "can_join",
|
|
||||||
"aliases": r[5],
|
|
||||||
}
|
|
||||||
for r in rows
|
|
||||||
if r[5] # We only return rooms that have at least one alias.
|
|
||||||
]
|
|
||||||
|
|
||||||
defer.returnValue(ret)
|
|
||||||
|
|
||||||
def _store_room_topic_txn(self, txn, event):
|
def _store_room_topic_txn(self, txn, event):
|
||||||
if hasattr(event, "content") and "topic" in event.content:
|
if hasattr(event, "content") and "topic" in event.content:
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
|
|
|
@ -58,6 +58,10 @@ class RoomMemberStore(SQLBaseStore):
|
||||||
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
|
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
|
||||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
||||||
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||||
|
txn.call_after(
|
||||||
|
self._membership_stream_cache.entity_has_changed,
|
||||||
|
event.state_key, event.internal_metadata.stream_ordering
|
||||||
|
)
|
||||||
|
|
||||||
def get_room_member(self, user_id, room_id):
|
def get_room_member(self, user_id, room_id):
|
||||||
"""Retrieve the current state of a room member.
|
"""Retrieve the current state of a room member.
|
||||||
|
|
16
synapse/storage/schema/delta/28/public_roms_index.sql
Normal file
16
synapse/storage/schema/delta/28/public_roms_index.sql
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE INDEX public_room_index on rooms(is_public);
|
31
synapse/storage/schema/delta/29/push_actions.sql
Normal file
31
synapse/storage/schema/delta/29/push_actions.sql
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
ALTER TABLE event_push_actions ADD COLUMN topological_ordering BIGINT;
|
||||||
|
ALTER TABLE event_push_actions ADD COLUMN stream_ordering BIGINT;
|
||||||
|
ALTER TABLE event_push_actions ADD COLUMN notif SMALLINT;
|
||||||
|
ALTER TABLE event_push_actions ADD COLUMN highlight SMALLINT;
|
||||||
|
|
||||||
|
UPDATE event_push_actions SET stream_ordering = (
|
||||||
|
SELECT stream_ordering FROM events WHERE event_id = event_push_actions.event_id
|
||||||
|
), topological_ordering = (
|
||||||
|
SELECT topological_ordering FROM events WHERE event_id = event_push_actions.event_id
|
||||||
|
);
|
||||||
|
|
||||||
|
UPDATE event_push_actions SET notif = 1, highlight = 0;
|
||||||
|
|
||||||
|
CREATE INDEX event_push_actions_rm_tokens on event_push_actions(
|
||||||
|
user_id, room_id, topological_ordering, stream_ordering
|
||||||
|
);
|
|
@ -171,15 +171,10 @@ class StateStore(SQLBaseStore):
|
||||||
events = yield self._get_events(event_ids, get_prev_content=False)
|
events = yield self._get_events(event_ids, get_prev_content=False)
|
||||||
defer.returnValue(events)
|
defer.returnValue(events)
|
||||||
|
|
||||||
def _get_state_groups_from_groups(self, groups_and_types):
|
def _get_state_groups_from_groups(self, groups, types):
|
||||||
"""Returns dictionary state_group -> state event ids
|
"""Returns dictionary state_group -> state event ids
|
||||||
|
|
||||||
Args:
|
|
||||||
groups_and_types (list): list of 2-tuple (`group`, `types`)
|
|
||||||
"""
|
"""
|
||||||
def f(txn):
|
def f(txn, groups):
|
||||||
results = {}
|
|
||||||
for group, types in groups_and_types:
|
|
||||||
if types is not None:
|
if types is not None:
|
||||||
where_clause = "AND (%s)" % (
|
where_clause = "AND (%s)" % (
|
||||||
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
|
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
|
||||||
|
@ -188,23 +183,30 @@ class StateStore(SQLBaseStore):
|
||||||
where_clause = ""
|
where_clause = ""
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT event_id FROM state_groups_state WHERE"
|
"SELECT state_group, event_id FROM state_groups_state WHERE"
|
||||||
" state_group = ? %s"
|
" state_group IN (%s) %s" % (
|
||||||
) % (where_clause,)
|
",".join("?" for _ in groups),
|
||||||
|
where_clause,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
args = [group]
|
args = list(groups)
|
||||||
if types is not None:
|
if types is not None:
|
||||||
args.extend([i for typ in types for i in typ])
|
args.extend([i for typ in types for i in typ])
|
||||||
|
|
||||||
txn.execute(sql, args)
|
txn.execute(sql, args)
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
results[group] = [r[0] for r in txn.fetchall()]
|
results = {}
|
||||||
|
for row in rows:
|
||||||
|
results.setdefault(row["state_group"], []).append(row["event_id"])
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
|
||||||
|
for chunk in chunks:
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"_get_state_groups_from_groups",
|
"_get_state_groups_from_groups",
|
||||||
f,
|
f, chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -264,26 +266,20 @@ class StateStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
|
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
|
||||||
num_args=1)
|
num_args=1, inlineCallbacks=True)
|
||||||
def _get_state_group_for_events(self, event_ids):
|
def _get_state_group_for_events(self, event_ids):
|
||||||
"""Returns mapping event_id -> state_group
|
"""Returns mapping event_id -> state_group
|
||||||
"""
|
"""
|
||||||
def f(txn):
|
rows = yield self._simple_select_many_batch(
|
||||||
results = {}
|
|
||||||
for event_id in event_ids:
|
|
||||||
results[event_id] = self._simple_select_one_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
keyvalues={
|
column="event_id",
|
||||||
"event_id": event_id,
|
iterable=event_ids,
|
||||||
},
|
keyvalues={},
|
||||||
retcol="state_group",
|
retcols=("event_id", "state_group",),
|
||||||
allow_none=True,
|
desc="_get_state_group_for_events",
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
|
||||||
|
|
||||||
return self.runInteraction("_get_state_group_for_events", f)
|
|
||||||
|
|
||||||
def _get_some_state_from_cache(self, group, types):
|
def _get_some_state_from_cache(self, group, types):
|
||||||
"""Checks if group is in cache. See `_get_state_for_groups`
|
"""Checks if group is in cache. See `_get_state_for_groups`
|
||||||
|
@ -355,7 +351,7 @@ class StateStore(SQLBaseStore):
|
||||||
all events are returned.
|
all events are returned.
|
||||||
"""
|
"""
|
||||||
results = {}
|
results = {}
|
||||||
missing_groups_and_types = []
|
missing_groups = []
|
||||||
if types is not None:
|
if types is not None:
|
||||||
for group in set(groups):
|
for group in set(groups):
|
||||||
state_dict, missing_types, got_all = self._get_some_state_from_cache(
|
state_dict, missing_types, got_all = self._get_some_state_from_cache(
|
||||||
|
@ -364,7 +360,7 @@ class StateStore(SQLBaseStore):
|
||||||
results[group] = state_dict
|
results[group] = state_dict
|
||||||
|
|
||||||
if not got_all:
|
if not got_all:
|
||||||
missing_groups_and_types.append((group, missing_types))
|
missing_groups.append(group)
|
||||||
else:
|
else:
|
||||||
for group in set(groups):
|
for group in set(groups):
|
||||||
state_dict, got_all = self._get_all_state_from_cache(
|
state_dict, got_all = self._get_all_state_from_cache(
|
||||||
|
@ -373,9 +369,9 @@ class StateStore(SQLBaseStore):
|
||||||
results[group] = state_dict
|
results[group] = state_dict
|
||||||
|
|
||||||
if not got_all:
|
if not got_all:
|
||||||
missing_groups_and_types.append((group, None))
|
missing_groups.append(group)
|
||||||
|
|
||||||
if not missing_groups_and_types:
|
if not missing_groups:
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
group: {
|
group: {
|
||||||
type_tuple: event
|
type_tuple: event
|
||||||
|
@ -389,7 +385,7 @@ class StateStore(SQLBaseStore):
|
||||||
cache_seq_num = self._state_group_cache.sequence
|
cache_seq_num = self._state_group_cache.sequence
|
||||||
|
|
||||||
group_state_dict = yield self._get_state_groups_from_groups(
|
group_state_dict = yield self._get_state_groups_from_groups(
|
||||||
missing_groups_and_types
|
missing_groups, types
|
||||||
)
|
)
|
||||||
|
|
||||||
state_events = yield self._get_events(
|
state_events = yield self._get_events(
|
||||||
|
|
|
@ -37,10 +37,9 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logcontext import preserve_fn
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -78,13 +77,6 @@ def upper_bound(token):
|
||||||
|
|
||||||
|
|
||||||
class StreamStore(SQLBaseStore):
|
class StreamStore(SQLBaseStore):
|
||||||
def __init__(self, hs):
|
|
||||||
super(StreamStore, self).__init__(hs)
|
|
||||||
|
|
||||||
self._events_stream_cache = StreamChangeCache(
|
|
||||||
"EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None)
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
|
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
|
||||||
# NB this lives here instead of appservice.py so we can reuse the
|
# NB this lives here instead of appservice.py so we can reuse the
|
||||||
|
@ -177,14 +169,14 @@ class StreamStore(SQLBaseStore):
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
room_ids = list(room_ids)
|
room_ids = list(room_ids)
|
||||||
for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)):
|
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
|
||||||
res = yield defer.gatherResults([
|
res = yield defer.gatherResults([
|
||||||
self.get_room_events_stream_for_room(
|
preserve_fn(self.get_room_events_stream_for_room)(
|
||||||
room_id, from_key, to_key, limit
|
room_id, from_key, to_key, limit,
|
||||||
).addCallback(lambda r, rm: (rm, r), room_id)
|
)
|
||||||
for room_id in room_ids
|
for room_id in room_ids
|
||||||
])
|
])
|
||||||
results.update(dict(res))
|
results.update(dict(zip(rm_ids, res)))
|
||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
@ -229,8 +221,11 @@ class StreamStore(SQLBaseStore):
|
||||||
|
|
||||||
rows = self.cursor_to_dict(txn)
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
ret = self._get_events_txn(
|
return rows
|
||||||
txn,
|
|
||||||
|
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
|
||||||
|
|
||||||
|
ret = yield self._get_events(
|
||||||
[r["event_id"] for r in rows],
|
[r["event_id"] for r in rows],
|
||||||
get_prev_content=True
|
get_prev_content=True
|
||||||
)
|
)
|
||||||
|
@ -246,11 +241,10 @@ class StreamStore(SQLBaseStore):
|
||||||
# get.
|
# get.
|
||||||
key = from_key
|
key = from_key
|
||||||
|
|
||||||
return ret, key
|
defer.returnValue((ret, key))
|
||||||
res = yield self.runInteraction("get_room_events_stream_for_room", f)
|
|
||||||
defer.returnValue(res)
|
|
||||||
|
|
||||||
def get_room_changes_for_user(self, user_id, from_key, to_key):
|
@defer.inlineCallbacks
|
||||||
|
def get_membership_changes_for_user(self, user_id, from_key, to_key):
|
||||||
if from_key is not None:
|
if from_key is not None:
|
||||||
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||||
else:
|
else:
|
||||||
|
@ -258,7 +252,14 @@ class StreamStore(SQLBaseStore):
|
||||||
to_id = RoomStreamToken.parse_stream_token(to_key).stream
|
to_id = RoomStreamToken.parse_stream_token(to_key).stream
|
||||||
|
|
||||||
if from_key == to_key:
|
if from_key == to_key:
|
||||||
return defer.succeed([])
|
defer.returnValue([])
|
||||||
|
|
||||||
|
if from_id:
|
||||||
|
has_changed = self._membership_stream_cache.has_entity_changed(
|
||||||
|
user_id, int(from_id)
|
||||||
|
)
|
||||||
|
if not has_changed:
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
if from_id is not None:
|
if from_id is not None:
|
||||||
|
@ -283,17 +284,19 @@ class StreamStore(SQLBaseStore):
|
||||||
txn.execute(sql, (user_id, to_id,))
|
txn.execute(sql, (user_id, to_id,))
|
||||||
rows = self.cursor_to_dict(txn)
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
ret = self._get_events_txn(
|
return rows
|
||||||
txn,
|
|
||||||
|
rows = yield self.runInteraction("get_membership_changes_for_user", f)
|
||||||
|
|
||||||
|
ret = yield self._get_events(
|
||||||
[r["event_id"] for r in rows],
|
[r["event_id"] for r in rows],
|
||||||
get_prev_content=True
|
get_prev_content=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return ret
|
self._set_before_and_after(ret, rows, topo_order=False)
|
||||||
|
|
||||||
return self.runInteraction("get_room_changes_for_user", f)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
@log_function
|
|
||||||
def get_room_events_stream(
|
def get_room_events_stream(
|
||||||
self,
|
self,
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -324,11 +327,6 @@ class StreamStore(SQLBaseStore):
|
||||||
" WHERE m.user_id = ? AND m.membership = 'join'"
|
" WHERE m.user_id = ? AND m.membership = 'join'"
|
||||||
)
|
)
|
||||||
current_room_membership_args = [user_id]
|
current_room_membership_args = [user_id]
|
||||||
if room_ids:
|
|
||||||
current_room_membership_sql += " AND m.room_id in (%s)" % (
|
|
||||||
",".join(map(lambda _: "?", room_ids))
|
|
||||||
)
|
|
||||||
current_room_membership_args = [user_id] + room_ids
|
|
||||||
|
|
||||||
# We also want to get any membership events about that user, e.g.
|
# We also want to get any membership events about that user, e.g.
|
||||||
# invites or leave notifications.
|
# invites or leave notifications.
|
||||||
|
@ -567,6 +565,7 @@ class StreamStore(SQLBaseStore):
|
||||||
table="events",
|
table="events",
|
||||||
keyvalues={"event_id": event_id},
|
keyvalues={"event_id": event_id},
|
||||||
retcols=("stream_ordering", "topological_ordering"),
|
retcols=("stream_ordering", "topological_ordering"),
|
||||||
|
desc="get_topological_token_for_event",
|
||||||
).addCallback(lambda row: "t%d-%d" % (
|
).addCallback(lambda row: "t%d-%d" % (
|
||||||
row["topological_ordering"], row["stream_ordering"],)
|
row["topological_ordering"], row["stream_ordering"],)
|
||||||
)
|
)
|
||||||
|
@ -604,6 +603,10 @@ class StreamStore(SQLBaseStore):
|
||||||
internal = event.internal_metadata
|
internal = event.internal_metadata
|
||||||
internal.before = str(RoomStreamToken(topo, stream - 1))
|
internal.before = str(RoomStreamToken(topo, stream - 1))
|
||||||
internal.after = str(RoomStreamToken(topo, stream))
|
internal.after = str(RoomStreamToken(topo, stream))
|
||||||
|
internal.order = (
|
||||||
|
int(topo) if topo else 0,
|
||||||
|
int(stream),
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_events_around(self, room_id, event_id, before_limit, after_limit):
|
def get_events_around(self, room_id, event_id, before_limit, after_limit):
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
|
|
||||||
from twisted.internet import defer, reactor, task
|
from twisted.internet import defer, reactor, task
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ class Clock(object):
|
||||||
|
|
||||||
def looping_call(self, f, msec):
|
def looping_call(self, f, msec):
|
||||||
l = task.LoopingCall(f)
|
l = task.LoopingCall(f)
|
||||||
l.start(msec/1000.0, now=False)
|
l.start(msec / 1000.0, now=False)
|
||||||
return l
|
return l
|
||||||
|
|
||||||
def stop_looping_call(self, loop):
|
def stop_looping_call(self, loop):
|
||||||
|
@ -61,10 +61,8 @@ class Clock(object):
|
||||||
*args: Postional arguments to pass to function.
|
*args: Postional arguments to pass to function.
|
||||||
**kwargs: Key arguments to pass to function.
|
**kwargs: Key arguments to pass to function.
|
||||||
"""
|
"""
|
||||||
current_context = LoggingContext.current_context()
|
|
||||||
|
|
||||||
def wrapped_callback(*args, **kwargs):
|
def wrapped_callback(*args, **kwargs):
|
||||||
with PreserveLoggingContext(current_context):
|
with PreserveLoggingContext():
|
||||||
callback(*args, **kwargs)
|
callback(*args, **kwargs)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
|
|
|
@ -16,13 +16,16 @@
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
from .logcontext import preserve_context_over_deferred
|
from .logcontext import PreserveLoggingContext
|
||||||
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def sleep(seconds):
|
def sleep(seconds):
|
||||||
d = defer.Deferred()
|
d = defer.Deferred()
|
||||||
|
with PreserveLoggingContext():
|
||||||
reactor.callLater(seconds, d.callback, seconds)
|
reactor.callLater(seconds, d.callback, seconds)
|
||||||
return preserve_context_over_deferred(d)
|
res = yield d
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
|
||||||
def run_on_reactor():
|
def run_on_reactor():
|
||||||
|
@ -54,6 +57,7 @@ class ObservableDeferred(object):
|
||||||
object.__setattr__(self, "_result", (True, r))
|
object.__setattr__(self, "_result", (True, r))
|
||||||
while self._observers:
|
while self._observers:
|
||||||
try:
|
try:
|
||||||
|
# TODO: Handle errors here.
|
||||||
self._observers.pop().callback(r)
|
self._observers.pop().callback(r)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
@ -63,6 +67,7 @@ class ObservableDeferred(object):
|
||||||
object.__setattr__(self, "_result", (False, f))
|
object.__setattr__(self, "_result", (False, f))
|
||||||
while self._observers:
|
while self._observers:
|
||||||
try:
|
try:
|
||||||
|
# TODO: Handle errors here.
|
||||||
self._observers.pop().errback(f)
|
self._observers.pop().errback(f)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.treecache import TreeCache
|
from synapse.util.caches.treecache import TreeCache
|
||||||
|
from synapse.util.logcontext import (
|
||||||
|
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
|
||||||
|
)
|
||||||
|
|
||||||
from . import caches_by_name, DEBUG_CACHES, cache_counter
|
from . import caches_by_name, DEBUG_CACHES, cache_counter
|
||||||
|
|
||||||
|
@ -149,7 +152,7 @@ class CacheDescriptor(object):
|
||||||
self.lru = lru
|
self.lru = lru
|
||||||
self.tree = tree
|
self.tree = tree
|
||||||
|
|
||||||
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
|
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
|
||||||
|
|
||||||
if len(self.arg_names) < self.num_args:
|
if len(self.arg_names) < self.num_args:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -190,7 +193,7 @@ class CacheDescriptor(object):
|
||||||
defer.returnValue(cached_result)
|
defer.returnValue(cached_result)
|
||||||
observer.addCallback(check_result)
|
observer.addCallback(check_result)
|
||||||
|
|
||||||
return observer
|
return preserve_context_over_deferred(observer)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
# Get the sequence number of the cache before reading from the
|
# Get the sequence number of the cache before reading from the
|
||||||
# database so that we can tell if the cache is invalidated
|
# database so that we can tell if the cache is invalidated
|
||||||
|
@ -198,6 +201,7 @@ class CacheDescriptor(object):
|
||||||
sequence = self.cache.sequence
|
sequence = self.cache.sequence
|
||||||
|
|
||||||
ret = defer.maybeDeferred(
|
ret = defer.maybeDeferred(
|
||||||
|
preserve_context_over_fn,
|
||||||
self.function_to_call,
|
self.function_to_call,
|
||||||
obj, *args, **kwargs
|
obj, *args, **kwargs
|
||||||
)
|
)
|
||||||
|
@ -211,7 +215,7 @@ class CacheDescriptor(object):
|
||||||
ret = ObservableDeferred(ret, consumeErrors=True)
|
ret = ObservableDeferred(ret, consumeErrors=True)
|
||||||
self.cache.update(sequence, cache_key, ret)
|
self.cache.update(sequence, cache_key, ret)
|
||||||
|
|
||||||
return ret.observe()
|
return preserve_context_over_deferred(ret.observe())
|
||||||
|
|
||||||
wrapped.invalidate = self.cache.invalidate
|
wrapped.invalidate = self.cache.invalidate
|
||||||
wrapped.invalidate_all = self.cache.invalidate_all
|
wrapped.invalidate_all = self.cache.invalidate_all
|
||||||
|
@ -250,7 +254,7 @@ class CacheListDescriptor(object):
|
||||||
self.num_args = num_args
|
self.num_args = num_args
|
||||||
self.list_name = list_name
|
self.list_name = list_name
|
||||||
|
|
||||||
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
|
self.arg_names = inspect.getargspec(orig).args[1:num_args + 1]
|
||||||
self.list_pos = self.arg_names.index(self.list_name)
|
self.list_pos = self.arg_names.index(self.list_name)
|
||||||
|
|
||||||
self.cache = cache
|
self.cache = cache
|
||||||
|
@ -299,6 +303,7 @@ class CacheListDescriptor(object):
|
||||||
args_to_call[self.list_name] = missing
|
args_to_call[self.list_name] = missing
|
||||||
|
|
||||||
ret_d = defer.maybeDeferred(
|
ret_d = defer.maybeDeferred(
|
||||||
|
preserve_context_over_fn,
|
||||||
self.function_to_call,
|
self.function_to_call,
|
||||||
**args_to_call
|
**args_to_call
|
||||||
)
|
)
|
||||||
|
@ -308,6 +313,7 @@ class CacheListDescriptor(object):
|
||||||
# We need to create deferreds for each arg in the list so that
|
# We need to create deferreds for each arg in the list so that
|
||||||
# we can insert the new deferred into the cache.
|
# we can insert the new deferred into the cache.
|
||||||
for arg in missing:
|
for arg in missing:
|
||||||
|
with PreserveLoggingContext():
|
||||||
observer = ret_d.observe()
|
observer = ret_d.observe()
|
||||||
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
||||||
|
|
||||||
|
@ -327,10 +333,10 @@ class CacheListDescriptor(object):
|
||||||
|
|
||||||
cached[arg] = res
|
cached[arg] = res
|
||||||
|
|
||||||
return defer.gatherResults(
|
return preserve_context_over_deferred(defer.gatherResults(
|
||||||
cached.values(),
|
cached.values(),
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
|
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
|
||||||
|
|
||||||
obj.__dict__[self.orig.__name__] = wrapped
|
obj.__dict__[self.orig.__name__] = wrapped
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ class ExpiringCache(object):
|
||||||
def f():
|
def f():
|
||||||
self._prune_cache()
|
self._prune_cache()
|
||||||
|
|
||||||
self._clock.looping_call(f, self._expiry_ms/2)
|
self._clock.looping_call(f, self._expiry_ms / 2)
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
now = self._clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
|
|
|
@ -87,7 +87,8 @@ class SnapshotCache(object):
|
||||||
# expire from the rotation of that cache.
|
# expire from the rotation of that cache.
|
||||||
self.next_result_cache[key] = result
|
self.next_result_cache[key] = result
|
||||||
self.pending_result_cache.pop(key, None)
|
self.pending_result_cache.pop(key, None)
|
||||||
|
return r
|
||||||
|
|
||||||
result.observe().addBoth(shuffle_along)
|
result.addBoth(shuffle_along)
|
||||||
|
|
||||||
return result.observe()
|
return result.observe()
|
||||||
|
|
|
@ -32,7 +32,7 @@ class StreamChangeCache(object):
|
||||||
entities that may have changed since that position. If position key is too
|
entities that may have changed since that position. If position key is too
|
||||||
old then the cache will simply return all given entities.
|
old then the cache will simply return all given entities.
|
||||||
"""
|
"""
|
||||||
def __init__(self, name, current_stream_pos, max_size=10000):
|
def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
|
||||||
self._max_size = max_size
|
self._max_size = max_size
|
||||||
self._entity_to_key = {}
|
self._entity_to_key = {}
|
||||||
self._cache = sorteddict()
|
self._cache = sorteddict()
|
||||||
|
@ -40,6 +40,9 @@ class StreamChangeCache(object):
|
||||||
self.name = name
|
self.name = name
|
||||||
caches_by_name[self.name] = self._cache
|
caches_by_name[self.name] = self._cache
|
||||||
|
|
||||||
|
for entity, stream_pos in prefilled_cache.items():
|
||||||
|
self.entity_has_changed(entity, stream_pos)
|
||||||
|
|
||||||
def has_entity_changed(self, entity, stream_pos):
|
def has_entity_changed(self, entity, stream_pos):
|
||||||
"""Returns True if the entity may have been updated since stream_pos
|
"""Returns True if the entity may have been updated since stream_pos
|
||||||
"""
|
"""
|
||||||
|
@ -49,15 +52,10 @@ class StreamChangeCache(object):
|
||||||
cache_counter.inc_misses(self.name)
|
cache_counter.inc_misses(self.name)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if stream_pos == self._earliest_known_stream_pos:
|
|
||||||
# If the same as the earliest key, assume nothing has changed.
|
|
||||||
cache_counter.inc_hits(self.name)
|
|
||||||
return False
|
|
||||||
|
|
||||||
latest_entity_change_pos = self._entity_to_key.get(entity, None)
|
latest_entity_change_pos = self._entity_to_key.get(entity, None)
|
||||||
if latest_entity_change_pos is None:
|
if latest_entity_change_pos is None:
|
||||||
cache_counter.inc_misses(self.name)
|
cache_counter.inc_hits(self.name)
|
||||||
return True
|
return False
|
||||||
|
|
||||||
if stream_pos < latest_entity_change_pos:
|
if stream_pos < latest_entity_change_pos:
|
||||||
cache_counter.inc_misses(self.name)
|
cache_counter.inc_misses(self.name)
|
||||||
|
@ -95,7 +93,7 @@ class StreamChangeCache(object):
|
||||||
|
|
||||||
if stream_pos > self._earliest_known_stream_pos:
|
if stream_pos > self._earliest_known_stream_pos:
|
||||||
old_pos = self._entity_to_key.get(entity, None)
|
old_pos = self._entity_to_key.get(entity, None)
|
||||||
if old_pos:
|
if old_pos is not None:
|
||||||
stream_pos = max(stream_pos, old_pos)
|
stream_pos = max(stream_pos, old_pos)
|
||||||
self._cache.pop(old_pos, None)
|
self._cache.pop(old_pos, None)
|
||||||
self._cache[stream_pos] = entity
|
self._cache[stream_pos] = entity
|
||||||
|
|
|
@ -58,7 +58,7 @@ class TreeCache(object):
|
||||||
|
|
||||||
if n:
|
if n:
|
||||||
break
|
break
|
||||||
node_and_keys[i+1][0].pop(k)
|
node_and_keys[i + 1][0].pop(k)
|
||||||
|
|
||||||
popped, cnt = _strip_and_count_entires(popped)
|
popped, cnt = _strip_and_count_entires(popped)
|
||||||
self.size -= cnt
|
self.size -= cnt
|
||||||
|
|
|
@ -15,9 +15,7 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.logcontext import (
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
PreserveLoggingContext, preserve_context_over_deferred,
|
|
||||||
)
|
|
||||||
|
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
|
|
||||||
|
@ -97,6 +95,7 @@ class Signal(object):
|
||||||
Each observer callable may return a Deferred."""
|
Each observer callable may return a Deferred."""
|
||||||
self.observers.append(observer)
|
self.observers.append(observer)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def fire(self, *args, **kwargs):
|
def fire(self, *args, **kwargs):
|
||||||
"""Invokes every callable in the observer list, passing in the args and
|
"""Invokes every callable in the observer list, passing in the args and
|
||||||
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
||||||
|
@ -116,6 +115,7 @@ class Signal(object):
|
||||||
failure.getTracebackObject()))
|
failure.getTracebackObject()))
|
||||||
if not self.suppress_failures:
|
if not self.suppress_failures:
|
||||||
return failure
|
return failure
|
||||||
|
|
||||||
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
|
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
|
@ -124,8 +124,11 @@ class Signal(object):
|
||||||
for observer in self.observers
|
for observer in self.observers
|
||||||
]
|
]
|
||||||
|
|
||||||
d = defer.gatherResults(deferreds, consumeErrors=True)
|
res = yield defer.gatherResults(
|
||||||
|
deferreds, consumeErrors=True
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
d.addErrback(unwrapFirstError)
|
defer.returnValue(res)
|
||||||
|
|
||||||
return preserve_context_over_deferred(d)
|
def __repr__(self):
|
||||||
|
return "<Signal name=%r>" % (self.name,)
|
||||||
|
|
|
@ -41,13 +41,14 @@ except:
|
||||||
|
|
||||||
class LoggingContext(object):
|
class LoggingContext(object):
|
||||||
"""Additional context for log formatting. Contexts are scoped within a
|
"""Additional context for log formatting. Contexts are scoped within a
|
||||||
"with" block. Contexts inherit the state of their parent contexts.
|
"with" block.
|
||||||
Args:
|
Args:
|
||||||
name (str): Name for the context for debugging.
|
name (str): Name for the context for debugging.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"parent_context", "name", "usage_start", "usage_end", "main_thread", "__dict__"
|
"previous_context", "name", "usage_start", "usage_end", "main_thread",
|
||||||
|
"__dict__", "tag", "alive",
|
||||||
]
|
]
|
||||||
|
|
||||||
thread_local = threading.local()
|
thread_local = threading.local()
|
||||||
|
@ -72,10 +73,13 @@ class LoggingContext(object):
|
||||||
def add_database_transaction(self, duration_ms):
|
def add_database_transaction(self, duration_ms):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def __nonzero__(self):
|
||||||
|
return False
|
||||||
|
|
||||||
sentinel = Sentinel()
|
sentinel = Sentinel()
|
||||||
|
|
||||||
def __init__(self, name=None):
|
def __init__(self, name=None):
|
||||||
self.parent_context = None
|
self.previous_context = LoggingContext.current_context()
|
||||||
self.name = name
|
self.name = name
|
||||||
self.ru_stime = 0.
|
self.ru_stime = 0.
|
||||||
self.ru_utime = 0.
|
self.ru_utime = 0.
|
||||||
|
@ -83,6 +87,8 @@ class LoggingContext(object):
|
||||||
self.db_txn_duration = 0.
|
self.db_txn_duration = 0.
|
||||||
self.usage_start = None
|
self.usage_start = None
|
||||||
self.main_thread = threading.current_thread()
|
self.main_thread = threading.current_thread()
|
||||||
|
self.tag = ""
|
||||||
|
self.alive = True
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "%s@%x" % (self.name, id(self))
|
return "%s@%x" % (self.name, id(self))
|
||||||
|
@ -101,6 +107,7 @@ class LoggingContext(object):
|
||||||
The context that was previously active
|
The context that was previously active
|
||||||
"""
|
"""
|
||||||
current = cls.current_context()
|
current = cls.current_context()
|
||||||
|
|
||||||
if current is not context:
|
if current is not context:
|
||||||
current.stop()
|
current.stop()
|
||||||
cls.thread_local.current_context = context
|
cls.thread_local.current_context = context
|
||||||
|
@ -109,9 +116,13 @@ class LoggingContext(object):
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
"""Enters this logging context into thread local storage"""
|
"""Enters this logging context into thread local storage"""
|
||||||
if self.parent_context is not None:
|
old_context = self.set_current_context(self)
|
||||||
raise Exception("Attempt to enter logging context multiple times")
|
if self.previous_context != old_context:
|
||||||
self.parent_context = self.set_current_context(self)
|
logger.warn(
|
||||||
|
"Expected previous context %r, found %r",
|
||||||
|
self.previous_context, old_context
|
||||||
|
)
|
||||||
|
self.alive = True
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
|
@ -120,7 +131,7 @@ class LoggingContext(object):
|
||||||
Returns:
|
Returns:
|
||||||
None to avoid suppressing any exeptions that were thrown.
|
None to avoid suppressing any exeptions that were thrown.
|
||||||
"""
|
"""
|
||||||
current = self.set_current_context(self.parent_context)
|
current = self.set_current_context(self.previous_context)
|
||||||
if current is not self:
|
if current is not self:
|
||||||
if current is self.sentinel:
|
if current is self.sentinel:
|
||||||
logger.debug("Expected logging context %s has been lost", self)
|
logger.debug("Expected logging context %s has been lost", self)
|
||||||
|
@ -130,16 +141,11 @@ class LoggingContext(object):
|
||||||
current,
|
current,
|
||||||
self
|
self
|
||||||
)
|
)
|
||||||
self.parent_context = None
|
self.previous_context = None
|
||||||
|
self.alive = False
|
||||||
def __getattr__(self, name):
|
|
||||||
"""Delegate member lookup to parent context"""
|
|
||||||
return getattr(self.parent_context, name)
|
|
||||||
|
|
||||||
def copy_to(self, record):
|
def copy_to(self, record):
|
||||||
"""Copy fields from this context and its parents to the record"""
|
"""Copy fields from this context to the record"""
|
||||||
if self.parent_context is not None:
|
|
||||||
self.parent_context.copy_to(record)
|
|
||||||
for key, value in self.__dict__.items():
|
for key, value in self.__dict__.items():
|
||||||
setattr(record, key, value)
|
setattr(record, key, value)
|
||||||
|
|
||||||
|
@ -208,7 +214,7 @@ class PreserveLoggingContext(object):
|
||||||
exited. Used to restore the context after a function using
|
exited. Used to restore the context after a function using
|
||||||
@defer.inlineCallbacks is resumed by a callback from the reactor."""
|
@defer.inlineCallbacks is resumed by a callback from the reactor."""
|
||||||
|
|
||||||
__slots__ = ["current_context", "new_context"]
|
__slots__ = ["current_context", "new_context", "has_parent"]
|
||||||
|
|
||||||
def __init__(self, new_context=LoggingContext.sentinel):
|
def __init__(self, new_context=LoggingContext.sentinel):
|
||||||
self.new_context = new_context
|
self.new_context = new_context
|
||||||
|
@ -219,12 +225,27 @@ class PreserveLoggingContext(object):
|
||||||
self.new_context
|
self.new_context
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.current_context:
|
||||||
|
self.has_parent = self.current_context.previous_context is not None
|
||||||
|
if not self.current_context.alive:
|
||||||
|
logger.debug(
|
||||||
|
"Entering dead context: %s",
|
||||||
|
self.current_context,
|
||||||
|
)
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
"""Restores the current logging context"""
|
"""Restores the current logging context"""
|
||||||
LoggingContext.set_current_context(self.current_context)
|
context = LoggingContext.set_current_context(self.current_context)
|
||||||
|
|
||||||
|
if context != self.new_context:
|
||||||
|
logger.debug(
|
||||||
|
"Unexpected logging context: %s is not %s",
|
||||||
|
context, self.new_context,
|
||||||
|
)
|
||||||
|
|
||||||
if self.current_context is not LoggingContext.sentinel:
|
if self.current_context is not LoggingContext.sentinel:
|
||||||
if self.current_context.parent_context is None:
|
if not self.current_context.alive:
|
||||||
logger.warn(
|
logger.debug(
|
||||||
"Restoring dead context: %s",
|
"Restoring dead context: %s",
|
||||||
self.current_context,
|
self.current_context,
|
||||||
)
|
)
|
||||||
|
@ -284,3 +305,74 @@ def preserve_context_over_deferred(deferred):
|
||||||
d = _PreservingContextDeferred(current_context)
|
d = _PreservingContextDeferred(current_context)
|
||||||
deferred.chainDeferred(d)
|
deferred.chainDeferred(d)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def preserve_fn(f):
|
||||||
|
"""Ensures that function is called with correct context and that context is
|
||||||
|
restored after return. Useful for wrapping functions that return a deferred
|
||||||
|
which you don't yield on.
|
||||||
|
"""
|
||||||
|
current = LoggingContext.current_context()
|
||||||
|
|
||||||
|
def g(*args, **kwargs):
|
||||||
|
with PreserveLoggingContext(current):
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
# modules to ignore in `logcontext_tracer`
|
||||||
|
_to_ignore = [
|
||||||
|
"synapse.util.logcontext",
|
||||||
|
"synapse.http.server",
|
||||||
|
"synapse.storage._base",
|
||||||
|
"synapse.util.async",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def logcontext_tracer(frame, event, arg):
|
||||||
|
"""A tracer that logs whenever a logcontext "unexpectedly" changes within
|
||||||
|
a function. Probably inaccurate.
|
||||||
|
|
||||||
|
Use by calling `sys.settrace(logcontext_tracer)` in the main thread.
|
||||||
|
"""
|
||||||
|
if event == 'call':
|
||||||
|
name = frame.f_globals["__name__"]
|
||||||
|
if name.startswith("synapse"):
|
||||||
|
if name == "synapse.util.logcontext":
|
||||||
|
if frame.f_code.co_name in ["__enter__", "__exit__"]:
|
||||||
|
tracer = frame.f_back.f_trace
|
||||||
|
if tracer:
|
||||||
|
tracer.just_changed = True
|
||||||
|
|
||||||
|
tracer = frame.f_trace
|
||||||
|
if tracer:
|
||||||
|
return tracer
|
||||||
|
|
||||||
|
if not any(name.startswith(ig) for ig in _to_ignore):
|
||||||
|
return LineTracer()
|
||||||
|
|
||||||
|
|
||||||
|
class LineTracer(object):
|
||||||
|
__slots__ = ["context", "just_changed"]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.context = LoggingContext.current_context()
|
||||||
|
self.just_changed = False
|
||||||
|
|
||||||
|
def __call__(self, frame, event, arg):
|
||||||
|
if event in 'line':
|
||||||
|
if self.just_changed:
|
||||||
|
self.context = LoggingContext.current_context()
|
||||||
|
self.just_changed = False
|
||||||
|
else:
|
||||||
|
c = LoggingContext.current_context()
|
||||||
|
if c != self.context:
|
||||||
|
logger.info(
|
||||||
|
"Context changed! %s -> %s, %s, %s",
|
||||||
|
self.context, c,
|
||||||
|
frame.f_code.co_filename, frame.f_lineno
|
||||||
|
)
|
||||||
|
self.context = c
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
|
@ -111,7 +111,7 @@ def time_function(f):
|
||||||
_log_debug_as_f(
|
_log_debug_as_f(
|
||||||
f,
|
f,
|
||||||
"[FUNC END] {%s-%d} %f",
|
"[FUNC END] {%s-%d} %f",
|
||||||
(func_name, id, end-start,),
|
(func_name, id, end - start,),
|
||||||
)
|
)
|
||||||
|
|
||||||
return r
|
return r
|
||||||
|
@ -168,3 +168,38 @@ def trace_function(f):
|
||||||
|
|
||||||
wrapped.__name__ = func_name
|
wrapped.__name__ = func_name
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def get_previous_frames():
|
||||||
|
s = inspect.currentframe().f_back.f_back
|
||||||
|
to_return = []
|
||||||
|
while s:
|
||||||
|
if s.f_globals["__name__"].startswith("synapse"):
|
||||||
|
filename, lineno, function, _, _ = inspect.getframeinfo(s)
|
||||||
|
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
|
||||||
|
|
||||||
|
to_return.append("{{ %s:%d %s - Args: %s }}" % (
|
||||||
|
filename, lineno, function, args_string
|
||||||
|
))
|
||||||
|
|
||||||
|
s = s.f_back
|
||||||
|
|
||||||
|
return ", ". join(to_return)
|
||||||
|
|
||||||
|
|
||||||
|
def get_previous_frame(ignore=[]):
|
||||||
|
s = inspect.currentframe().f_back.f_back
|
||||||
|
|
||||||
|
while s:
|
||||||
|
if s.f_globals["__name__"].startswith("synapse"):
|
||||||
|
if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
|
||||||
|
filename, lineno, function, _, _ = inspect.getframeinfo(s)
|
||||||
|
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
|
||||||
|
|
||||||
|
return "{{ %s:%d %s - Args: %s }}" % (
|
||||||
|
filename, lineno, function, args_string
|
||||||
|
)
|
||||||
|
|
||||||
|
s = s.f_back
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
97
synapse/util/metrics.py
Normal file
97
synapse/util/metrics.py
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
import synapse.metrics
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||||
|
|
||||||
|
block_timer = metrics.register_distribution(
|
||||||
|
"block_timer",
|
||||||
|
labels=["block_name"]
|
||||||
|
)
|
||||||
|
|
||||||
|
block_ru_utime = metrics.register_distribution(
|
||||||
|
"block_ru_utime", labels=["block_name"]
|
||||||
|
)
|
||||||
|
|
||||||
|
block_ru_stime = metrics.register_distribution(
|
||||||
|
"block_ru_stime", labels=["block_name"]
|
||||||
|
)
|
||||||
|
|
||||||
|
block_db_txn_count = metrics.register_distribution(
|
||||||
|
"block_db_txn_count", labels=["block_name"]
|
||||||
|
)
|
||||||
|
|
||||||
|
block_db_txn_duration = metrics.register_distribution(
|
||||||
|
"block_db_txn_duration", labels=["block_name"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Measure(object):
|
||||||
|
__slots__ = [
|
||||||
|
"clock", "name", "start_context", "start", "new_context", "ru_utime",
|
||||||
|
"ru_stime", "db_txn_count", "db_txn_duration"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, clock, name):
|
||||||
|
self.clock = clock
|
||||||
|
self.name = name
|
||||||
|
self.start_context = None
|
||||||
|
self.start = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start = self.clock.time_msec()
|
||||||
|
self.start_context = LoggingContext.current_context()
|
||||||
|
if self.start_context:
|
||||||
|
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
|
||||||
|
self.db_txn_count = self.start_context.db_txn_count
|
||||||
|
self.db_txn_duration = self.start_context.db_txn_duration
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
if exc_type is not None or not self.start_context:
|
||||||
|
return
|
||||||
|
|
||||||
|
duration = self.clock.time_msec() - self.start
|
||||||
|
block_timer.inc_by(duration, self.name)
|
||||||
|
|
||||||
|
context = LoggingContext.current_context()
|
||||||
|
|
||||||
|
if context != self.start_context:
|
||||||
|
logger.warn(
|
||||||
|
"Context have unexpectedly changed from '%s' to '%s'. (%r)",
|
||||||
|
context, self.start_context, self.name
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not context:
|
||||||
|
logger.warn("Expected context. (%r)", self.name)
|
||||||
|
return
|
||||||
|
|
||||||
|
ru_utime, ru_stime = context.get_resource_usage()
|
||||||
|
|
||||||
|
block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
|
||||||
|
block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
|
||||||
|
block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name)
|
||||||
|
block_db_txn_duration.inc_by(
|
||||||
|
context.db_txn_duration - self.db_txn_duration, self.name
|
||||||
|
)
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
||||||
from synapse.api.errors import LimitExceededError
|
from synapse.api.errors import LimitExceededError
|
||||||
|
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
|
from synapse.util.logcontext import preserve_fn
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import contextlib
|
import contextlib
|
||||||
|
@ -163,7 +164,7 @@ class _PerHostRatelimiter(object):
|
||||||
"Ratelimit [%s]: sleeping req",
|
"Ratelimit [%s]: sleeping req",
|
||||||
id(request_id),
|
id(request_id),
|
||||||
)
|
)
|
||||||
ret_defer = sleep(self.sleep_msec/1000.0)
|
ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
|
||||||
|
|
||||||
self.sleeping_requests.add(request_id)
|
self.sleeping_requests.add(request_id)
|
||||||
|
|
||||||
|
|
14
tests/config/__init__.py
Normal file
14
tests/config/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
50
tests/config/test_generate.py
Normal file
50
tests/config/test_generate.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os.path
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigGenerationTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.dir = tempfile.mkdtemp()
|
||||||
|
print self.dir
|
||||||
|
self.file = os.path.join(self.dir, "homeserver.yaml")
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.dir)
|
||||||
|
|
||||||
|
def test_generate_config_generates_files(self):
|
||||||
|
HomeServerConfig.load_config("", [
|
||||||
|
"--generate-config",
|
||||||
|
"-c", self.file,
|
||||||
|
"--report-stats=yes",
|
||||||
|
"-H", "lemurs.win"
|
||||||
|
])
|
||||||
|
|
||||||
|
self.assertSetEqual(
|
||||||
|
set([
|
||||||
|
"homeserver.yaml",
|
||||||
|
"lemurs.win.log.config",
|
||||||
|
"lemurs.win.signing.key",
|
||||||
|
"lemurs.win.tls.crt",
|
||||||
|
"lemurs.win.tls.dh",
|
||||||
|
"lemurs.win.tls.key",
|
||||||
|
]),
|
||||||
|
set(os.listdir(self.dir))
|
||||||
|
)
|
78
tests/config/test_load.py
Normal file
78
tests/config/test_load.py
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os.path
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import yaml
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigLoadingTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.dir = tempfile.mkdtemp()
|
||||||
|
print self.dir
|
||||||
|
self.file = os.path.join(self.dir, "homeserver.yaml")
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.dir)
|
||||||
|
|
||||||
|
def test_load_fails_if_server_name_missing(self):
|
||||||
|
self.generate_config_and_remove_lines_containing("server_name")
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
HomeServerConfig.load_config("", ["-c", self.file])
|
||||||
|
|
||||||
|
def test_generates_and_loads_macaroon_secret_key(self):
|
||||||
|
self.generate_config()
|
||||||
|
|
||||||
|
with open(self.file,
|
||||||
|
"r") as f:
|
||||||
|
raw = yaml.load(f)
|
||||||
|
self.assertIn("macaroon_secret_key", raw)
|
||||||
|
|
||||||
|
config = HomeServerConfig.load_config("", ["-c", self.file])
|
||||||
|
self.assertTrue(
|
||||||
|
hasattr(config, "macaroon_secret_key"),
|
||||||
|
"Want config to have attr macaroon_secret_key"
|
||||||
|
)
|
||||||
|
if len(config.macaroon_secret_key) < 5:
|
||||||
|
self.fail(
|
||||||
|
"Want macaroon secret key to be string of at least length 5,"
|
||||||
|
"was: %r" % (config.macaroon_secret_key,)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_load_succeeds_if_macaroon_secret_key_missing(self):
|
||||||
|
self.generate_config_and_remove_lines_containing("macaroon")
|
||||||
|
config1 = HomeServerConfig.load_config("", ["-c", self.file])
|
||||||
|
config2 = HomeServerConfig.load_config("", ["-c", self.file])
|
||||||
|
self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key)
|
||||||
|
|
||||||
|
def generate_config(self):
|
||||||
|
HomeServerConfig.load_config("", [
|
||||||
|
"--generate-config",
|
||||||
|
"-c", self.file,
|
||||||
|
"--report-stats=yes",
|
||||||
|
"-H", "lemurs.win"
|
||||||
|
])
|
||||||
|
|
||||||
|
def generate_config_and_remove_lines_containing(self, needle):
|
||||||
|
self.generate_config()
|
||||||
|
|
||||||
|
with open(self.file, "r") as f:
|
||||||
|
contents = f.readlines()
|
||||||
|
contents = [l for l in contents if needle not in l]
|
||||||
|
with open(self.file, "w") as f:
|
||||||
|
f.write("".join(contents))
|
|
@ -122,7 +122,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
|
||||||
self.ratelimiter = hs.get_ratelimiter()
|
self.ratelimiter = hs.get_ratelimiter()
|
||||||
self.ratelimiter.send_message.return_value = (True, 0)
|
self.ratelimiter.send_message.return_value = (True, 0)
|
||||||
hs.config.enable_registration_captcha = False
|
hs.config.enable_registration_captcha = False
|
||||||
hs.config.disable_registration = False
|
hs.config.enable_registration = True
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.hs.hostname = "superbig~testing~thing.com"
|
self.hs.hostname = "superbig~testing~thing.com"
|
||||||
self.hs.get_auth = Mock(return_value=self.auth)
|
self.hs.get_auth = Mock(return_value=self.auth)
|
||||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||||
self.hs.config.disable_registration = False
|
self.hs.config.enable_registration = True
|
||||||
|
|
||||||
# init the thing we're testing
|
# init the thing we're testing
|
||||||
self.servlet = RegisterRestServlet(self.hs)
|
self.servlet = RegisterRestServlet(self.hs)
|
||||||
|
@ -120,7 +120,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
}))
|
}))
|
||||||
|
|
||||||
def test_POST_disabled_registration(self):
|
def test_POST_disabled_registration(self):
|
||||||
self.hs.config.disable_registration = True
|
self.hs.config.enable_registration = False
|
||||||
self.request_data = json.dumps({
|
self.request_data = json.dumps({
|
||||||
"username": "kermit",
|
"username": "kermit",
|
||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
|
|
|
@ -51,32 +51,6 @@ class RoomStoreTestCase(unittest.TestCase):
|
||||||
(yield self.store.get_room(self.room.to_string()))
|
(yield self.store.get_room(self.room.to_string()))
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_get_rooms(self):
|
|
||||||
# get_rooms does an INNER JOIN on the room_aliases table :(
|
|
||||||
|
|
||||||
rooms = yield self.store.get_rooms(is_public=True)
|
|
||||||
# Should be empty before we add the alias
|
|
||||||
self.assertEquals([], rooms)
|
|
||||||
|
|
||||||
yield self.store.create_room_alias_association(
|
|
||||||
room_alias=self.alias,
|
|
||||||
room_id=self.room.to_string(),
|
|
||||||
servers=["test"]
|
|
||||||
)
|
|
||||||
|
|
||||||
rooms = yield self.store.get_rooms(is_public=True)
|
|
||||||
|
|
||||||
self.assertEquals(1, len(rooms))
|
|
||||||
self.assertEquals({
|
|
||||||
"name": None,
|
|
||||||
"room_id": self.room.to_string(),
|
|
||||||
"topic": None,
|
|
||||||
"aliases": [self.alias.to_string()],
|
|
||||||
"world_readable": False,
|
|
||||||
"guest_can_join": False,
|
|
||||||
}, rooms[0])
|
|
||||||
|
|
||||||
|
|
||||||
class RoomEventsStoreTestCase(unittest.TestCase):
|
class RoomEventsStoreTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ from .. import unittest
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
from synapse.util.logcontext import LoggingContext
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
|
||||||
|
|
||||||
class LoggingContextTestCase(unittest.TestCase):
|
class LoggingContextTestCase(unittest.TestCase):
|
||||||
|
|
||||||
def _check_test_key(self, value):
|
def _check_test_key(self, value):
|
||||||
|
@ -17,15 +18,6 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||||
context_one.test_key = "test"
|
context_one.test_key = "test"
|
||||||
self._check_test_key("test")
|
self._check_test_key("test")
|
||||||
|
|
||||||
def test_chaining(self):
|
|
||||||
with LoggingContext() as context_one:
|
|
||||||
context_one.test_key = "one"
|
|
||||||
with LoggingContext() as context_two:
|
|
||||||
self._check_test_key("one")
|
|
||||||
context_two.test_key = "two"
|
|
||||||
self._check_test_key("two")
|
|
||||||
self._check_test_key("one")
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_sleep(self):
|
def test_sleep(self):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -46,9 +46,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
config = Mock()
|
config = Mock()
|
||||||
config.signing_key = [MockKey()]
|
config.signing_key = [MockKey()]
|
||||||
config.event_cache_size = 1
|
config.event_cache_size = 1
|
||||||
config.disable_registration = False
|
config.enable_registration = True
|
||||||
config.macaroon_secret_key = "not even a little secret"
|
config.macaroon_secret_key = "not even a little secret"
|
||||||
config.server_name = "server.under.test"
|
config.server_name = "server.under.test"
|
||||||
|
config.trusted_third_party_id_servers = []
|
||||||
|
|
||||||
if "clock" not in kargs:
|
if "clock" not in kargs:
|
||||||
kargs["clock"] = MockClock()
|
kargs["clock"] = MockClock()
|
||||||
|
|
2
tox.ini
2
tox.ini
|
@ -11,7 +11,7 @@ deps =
|
||||||
setenv =
|
setenv =
|
||||||
PYTHONDONTWRITEBYTECODE = no_byte_code
|
PYTHONDONTWRITEBYTECODE = no_byte_code
|
||||||
commands =
|
commands =
|
||||||
/bin/bash -c "coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \
|
/bin/bash -c "find {toxinidir} -name '*.pyc' -delete ; coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \
|
||||||
{envbindir}/trial {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}"
|
{envbindir}/trial {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}"
|
||||||
{env:DUMP_COVERAGE_COMMAND:coverage report -m}
|
{env:DUMP_COVERAGE_COMMAND:coverage report -m}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue