mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 18:53:53 +01:00
Merge branch 'release-v0.5.1' of github.com:matrix-org/synapse
This commit is contained in:
commit
48ee9ddb22
57 changed files with 1040 additions and 735 deletions
|
@ -1,3 +1,11 @@
|
|||
Changes in synapse 0.5.1 (2014-11-26)
|
||||
=====================================
|
||||
See UPGRADES.rst for specific instructions on how to upgrade.
|
||||
|
||||
* Fix bug where we served up an Event that did not match its signatures.
|
||||
* Fix regression where we no longer correctly handled the case where a
|
||||
homeserver receives an event for a room it doesn't recognise (but is in.)
|
||||
|
||||
Changes in synapse 0.5.0 (2014-11-19)
|
||||
=====================================
|
||||
This release includes changes to the federation protocol and client-server API
|
||||
|
|
12
README.rst
12
README.rst
|
@ -69,8 +69,8 @@ command line utility which lets you easily see what the JSON APIs are up to).
|
|||
|
||||
Meanwhile, iOS and Android SDKs and clients are currently in development and available from:
|
||||
|
||||
* https://github.com/matrix-org/matrix-ios-sdk
|
||||
* https://github.com/matrix-org/matrix-android-sdk
|
||||
- https://github.com/matrix-org/matrix-ios-sdk
|
||||
- https://github.com/matrix-org/matrix-android-sdk
|
||||
|
||||
We'd like to invite you to join #matrix:matrix.org (via http://matrix.org/alpha), run a homeserver, take a look at the Matrix spec at
|
||||
http://matrix.org/docs/spec, experiment with the APIs and the demo
|
||||
|
@ -94,7 +94,7 @@ header files for python C extensions.
|
|||
Installing prerequisites on Ubuntu or Debian::
|
||||
|
||||
$ sudo apt-get install build-essential python2.7-dev libffi-dev \
|
||||
python-pip python-setuptools
|
||||
python-pip python-setuptools sqlite3
|
||||
|
||||
Installing prerequisites on Mac OS X::
|
||||
|
||||
|
@ -125,7 +125,7 @@ created. To reset the installation::
|
|||
pip seems to leak *lots* of memory during installation. For instance, a Linux
|
||||
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
|
||||
happens, you will have to individually install the dependencies which are
|
||||
failing, e.g.:
|
||||
failing, e.g.::
|
||||
|
||||
$ pip install --user twisted
|
||||
|
||||
|
@ -148,7 +148,7 @@ Troubleshooting Running
|
|||
-----------------------
|
||||
|
||||
If ``synctl`` fails with ``pkg_resources.DistributionNotFound`` errors you may
|
||||
need a newer version of setuptools than that provided by your OS.
|
||||
need a newer version of setuptools than that provided by your OS.::
|
||||
|
||||
$ sudo pip install setuptools --upgrade
|
||||
|
||||
|
@ -172,7 +172,7 @@ Homeserver Development
|
|||
======================
|
||||
|
||||
To check out a homeserver for development, clone the git repo into a working
|
||||
directory of your choice:
|
||||
directory of your choice::
|
||||
|
||||
$ git clone https://github.com/matrix-org/synapse.git
|
||||
$ cd synapse
|
||||
|
|
|
@ -1,3 +1,12 @@
|
|||
Upgrading to v0.5.1
|
||||
===================
|
||||
|
||||
Depending on precisely when you installed v0.5.0 you may have ended up with
|
||||
a stale release of the reference matrix webclient installed as a python module.
|
||||
To uninstall it and ensure you are depending on the latest module, please run::
|
||||
|
||||
$ pip uninstall syweb
|
||||
|
||||
Upgrading to v0.5.0
|
||||
===================
|
||||
|
||||
|
|
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
|||
0.5.0
|
||||
0.5.1
|
||||
|
|
|
@ -23,7 +23,7 @@ def get_targets(server_name):
|
|||
for srv in answers:
|
||||
yield (srv.target, srv.port)
|
||||
except dns.resolver.NXDOMAIN:
|
||||
yield (server_name, 8480)
|
||||
yield (server_name, 8448)
|
||||
|
||||
def get_server_keys(server_name, target, port):
|
||||
url = "https://%s:%i/_matrix/key/v1" % (target, port)
|
||||
|
|
4
setup.py
4
setup.py
|
@ -32,7 +32,7 @@ setup(
|
|||
description="Reference Synapse Home Server",
|
||||
install_requires=[
|
||||
"syutil==0.0.2",
|
||||
"matrix_angular_sdk==0.5.0",
|
||||
"matrix_angular_sdk==0.5.1",
|
||||
"Twisted>=14.0.0",
|
||||
"service_identity>=1.0.0",
|
||||
"pyopenssl>=0.14",
|
||||
|
@ -45,7 +45,7 @@ setup(
|
|||
dependency_links=[
|
||||
"https://github.com/matrix-org/syutil/tarball/v0.0.2#egg=syutil-0.0.2",
|
||||
"https://github.com/pyca/pynacl/tarball/52dbe2dc33f1#egg=pynacl-0.3.0",
|
||||
"https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.5.0/#egg=matrix_angular_sdk-0.5.0",
|
||||
"https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.5.1/#egg=matrix_angular_sdk-0.5.1",
|
||||
],
|
||||
setup_requires=[
|
||||
"setuptools_trial",
|
||||
|
|
|
@ -16,4 +16,4 @@
|
|||
""" This is a reference implementation of a synapse home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.5.0"
|
||||
__version__ = "0.5.1"
|
||||
|
|
|
@ -38,79 +38,66 @@ class Auth(object):
|
|||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
def check(self, event, raises=False):
|
||||
def check(self, event, auth_events):
|
||||
""" Checks if this event is correctly authed.
|
||||
|
||||
Returns:
|
||||
True if the auth checks pass.
|
||||
Raises:
|
||||
AuthError if there was a problem authorising this event. This will
|
||||
be raised only if raises=True.
|
||||
"""
|
||||
try:
|
||||
if hasattr(event, "room_id"):
|
||||
if event.old_state_events is None:
|
||||
# Oh, we don't know what the state of the room was, so we
|
||||
# are trusting that this is allowed (at least for now)
|
||||
logger.warn("Trusting event: %s", event.event_id)
|
||||
return True
|
||||
|
||||
if hasattr(event, "outlier") and event.outlier is True:
|
||||
# TODO (erikj): Auth for outliers is done differently.
|
||||
return True
|
||||
|
||||
if event.type == RoomCreateEvent.TYPE:
|
||||
# FIXME
|
||||
return True
|
||||
|
||||
# FIXME: Temp hack
|
||||
if event.type == RoomAliasesEvent.TYPE:
|
||||
return True
|
||||
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
allowed = self.is_membership_change_allowed(event)
|
||||
if allowed:
|
||||
logger.debug("Allowing! %s", event)
|
||||
else:
|
||||
logger.debug("Denying! %s", event)
|
||||
return allowed
|
||||
|
||||
self.check_event_sender_in_room(event)
|
||||
self._can_send_event(event)
|
||||
|
||||
if event.type == RoomPowerLevelsEvent.TYPE:
|
||||
self._check_power_levels(event)
|
||||
|
||||
if event.type == RoomRedactionEvent.TYPE:
|
||||
self._check_redaction(event)
|
||||
|
||||
logger.debug("Allowing! %s", event)
|
||||
if not hasattr(event, "room_id"):
|
||||
raise AuthError(500, "Event has no room_id: %s" % event)
|
||||
if auth_events is None:
|
||||
# Oh, we don't know what the state of the room was, so we
|
||||
# are trusting that this is allowed (at least for now)
|
||||
logger.warn("Trusting event: %s", event.event_id)
|
||||
return True
|
||||
else:
|
||||
raise AuthError(500, "Unknown event: %s" % event)
|
||||
|
||||
if event.type == RoomCreateEvent.TYPE:
|
||||
# FIXME
|
||||
return True
|
||||
|
||||
# FIXME: Temp hack
|
||||
if event.type == RoomAliasesEvent.TYPE:
|
||||
return True
|
||||
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
allowed = self.is_membership_change_allowed(
|
||||
event, auth_events
|
||||
)
|
||||
if allowed:
|
||||
logger.debug("Allowing! %s", event)
|
||||
else:
|
||||
logger.debug("Denying! %s", event)
|
||||
return allowed
|
||||
|
||||
self.check_event_sender_in_room(event, auth_events)
|
||||
self._can_send_event(event, auth_events)
|
||||
|
||||
if event.type == RoomPowerLevelsEvent.TYPE:
|
||||
self._check_power_levels(event, auth_events)
|
||||
|
||||
if event.type == RoomRedactionEvent.TYPE:
|
||||
self._check_redaction(event, auth_events)
|
||||
|
||||
logger.debug("Allowing! %s", event)
|
||||
except AuthError as e:
|
||||
logger.info(
|
||||
"Event auth check failed on event %s with msg: %s",
|
||||
event, e.msg
|
||||
)
|
||||
logger.info("Denying! %s", event)
|
||||
if raises:
|
||||
raise
|
||||
|
||||
return False
|
||||
raise
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_joined_room(self, room_id, user_id):
|
||||
try:
|
||||
member = yield self.store.get_room_member(
|
||||
room_id=room_id,
|
||||
user_id=user_id
|
||||
)
|
||||
self._check_joined_room(member, user_id, room_id)
|
||||
defer.returnValue(member)
|
||||
except AttributeError:
|
||||
pass
|
||||
defer.returnValue(None)
|
||||
member = yield self.state.get_current_state(
|
||||
room_id=room_id,
|
||||
event_type=RoomMemberEvent.TYPE,
|
||||
state_key=user_id
|
||||
)
|
||||
self._check_joined_room(member, user_id, room_id)
|
||||
defer.returnValue(member)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_host_in_room(self, room_id, host):
|
||||
|
@ -130,9 +117,9 @@ class Auth(object):
|
|||
|
||||
defer.returnValue(False)
|
||||
|
||||
def check_event_sender_in_room(self, event):
|
||||
def check_event_sender_in_room(self, event, auth_events):
|
||||
key = (RoomMemberEvent.TYPE, event.user_id, )
|
||||
member_event = event.state_events.get(key)
|
||||
member_event = auth_events.get(key)
|
||||
|
||||
return self._check_joined_room(
|
||||
member_event,
|
||||
|
@ -147,15 +134,15 @@ class Auth(object):
|
|||
))
|
||||
|
||||
@log_function
|
||||
def is_membership_change_allowed(self, event):
|
||||
def is_membership_change_allowed(self, event, auth_events):
|
||||
membership = event.content["membership"]
|
||||
|
||||
# Check if this is the room creator joining:
|
||||
if len(event.prev_events) == 1 and Membership.JOIN == membership:
|
||||
# Get room creation event:
|
||||
key = (RoomCreateEvent.TYPE, "", )
|
||||
create = event.old_state_events.get(key)
|
||||
if event.prev_events[0][0] == create.event_id:
|
||||
create = auth_events.get(key)
|
||||
if create and event.prev_events[0][0] == create.event_id:
|
||||
if create.content["creator"] == event.state_key:
|
||||
return True
|
||||
|
||||
|
@ -163,19 +150,19 @@ class Auth(object):
|
|||
|
||||
# get info about the caller
|
||||
key = (RoomMemberEvent.TYPE, event.user_id, )
|
||||
caller = event.old_state_events.get(key)
|
||||
caller = auth_events.get(key)
|
||||
|
||||
caller_in_room = caller and caller.membership == Membership.JOIN
|
||||
caller_invited = caller and caller.membership == Membership.INVITE
|
||||
|
||||
# get info about the target
|
||||
key = (RoomMemberEvent.TYPE, target_user_id, )
|
||||
target = event.old_state_events.get(key)
|
||||
target = auth_events.get(key)
|
||||
|
||||
target_in_room = target and target.membership == Membership.JOIN
|
||||
|
||||
key = (RoomJoinRulesEvent.TYPE, "", )
|
||||
join_rule_event = event.old_state_events.get(key)
|
||||
join_rule_event = auth_events.get(key)
|
||||
if join_rule_event:
|
||||
join_rule = join_rule_event.content.get(
|
||||
"join_rule", JoinRules.INVITE
|
||||
|
@ -186,11 +173,13 @@ class Auth(object):
|
|||
user_level = self._get_power_level_from_event_state(
|
||||
event,
|
||||
event.user_id,
|
||||
auth_events,
|
||||
)
|
||||
|
||||
ban_level, kick_level, redact_level = (
|
||||
self._get_ops_level_from_event_state(
|
||||
event
|
||||
event,
|
||||
auth_events,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -260,9 +249,9 @@ class Auth(object):
|
|||
|
||||
return True
|
||||
|
||||
def _get_power_level_from_event_state(self, event, user_id):
|
||||
def _get_power_level_from_event_state(self, event, user_id, auth_events):
|
||||
key = (RoomPowerLevelsEvent.TYPE, "", )
|
||||
power_level_event = event.old_state_events.get(key)
|
||||
power_level_event = auth_events.get(key)
|
||||
level = None
|
||||
if power_level_event:
|
||||
level = power_level_event.content.get("users", {}).get(user_id)
|
||||
|
@ -270,16 +259,16 @@ class Auth(object):
|
|||
level = power_level_event.content.get("users_default", 0)
|
||||
else:
|
||||
key = (RoomCreateEvent.TYPE, "", )
|
||||
create_event = event.old_state_events.get(key)
|
||||
create_event = auth_events.get(key)
|
||||
if (create_event is not None and
|
||||
create_event.content["creator"] == user_id):
|
||||
create_event.content["creator"] == user_id):
|
||||
return 100
|
||||
|
||||
return level
|
||||
|
||||
def _get_ops_level_from_event_state(self, event):
|
||||
def _get_ops_level_from_event_state(self, event, auth_events):
|
||||
key = (RoomPowerLevelsEvent.TYPE, "", )
|
||||
power_level_event = event.old_state_events.get(key)
|
||||
power_level_event = auth_events.get(key)
|
||||
|
||||
if power_level_event:
|
||||
return (
|
||||
|
@ -375,6 +364,11 @@ class Auth(object):
|
|||
key = (RoomMemberEvent.TYPE, event.user_id, )
|
||||
member_event = event.old_state_events.get(key)
|
||||
|
||||
key = (RoomCreateEvent.TYPE, "", )
|
||||
create_event = event.old_state_events.get(key)
|
||||
if create_event:
|
||||
auth_events.append(create_event.event_id)
|
||||
|
||||
if join_rule_event:
|
||||
join_rule = join_rule_event.content.get("join_rule")
|
||||
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
|
||||
|
@ -406,9 +400,9 @@ class Auth(object):
|
|||
event.auth_events = zip(auth_events, hashes)
|
||||
|
||||
@log_function
|
||||
def _can_send_event(self, event):
|
||||
def _can_send_event(self, event, auth_events):
|
||||
key = (RoomPowerLevelsEvent.TYPE, "", )
|
||||
send_level_event = event.old_state_events.get(key)
|
||||
send_level_event = auth_events.get(key)
|
||||
send_level = None
|
||||
if send_level_event:
|
||||
send_level = send_level_event.content.get("events", {}).get(
|
||||
|
@ -432,6 +426,7 @@ class Auth(object):
|
|||
user_level = self._get_power_level_from_event_state(
|
||||
event,
|
||||
event.user_id,
|
||||
auth_events,
|
||||
)
|
||||
|
||||
if user_level:
|
||||
|
@ -468,14 +463,16 @@ class Auth(object):
|
|||
|
||||
return True
|
||||
|
||||
def _check_redaction(self, event):
|
||||
def _check_redaction(self, event, auth_events):
|
||||
user_level = self._get_power_level_from_event_state(
|
||||
event,
|
||||
event.user_id,
|
||||
auth_events,
|
||||
)
|
||||
|
||||
_, _, redact_level = self._get_ops_level_from_event_state(
|
||||
event
|
||||
event,
|
||||
auth_events,
|
||||
)
|
||||
|
||||
if user_level < redact_level:
|
||||
|
@ -484,7 +481,7 @@ class Auth(object):
|
|||
"You don't have permission to redact events"
|
||||
)
|
||||
|
||||
def _check_power_levels(self, event):
|
||||
def _check_power_levels(self, event, auth_events):
|
||||
user_list = event.content.get("users", {})
|
||||
# Validate users
|
||||
for k, v in user_list.items():
|
||||
|
@ -499,7 +496,7 @@ class Auth(object):
|
|||
raise SynapseError(400, "Not a valid power level: %s" % (v,))
|
||||
|
||||
key = (event.type, event.state_key, )
|
||||
current_state = event.old_state_events.get(key)
|
||||
current_state = auth_events.get(key)
|
||||
|
||||
if not current_state:
|
||||
return
|
||||
|
@ -507,6 +504,7 @@ class Auth(object):
|
|||
user_level = self._get_power_level_from_event_state(
|
||||
event,
|
||||
event.user_id,
|
||||
auth_events,
|
||||
)
|
||||
|
||||
# Check other levels:
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Codes(object):
|
||||
UNAUTHORIZED = "M_UNAUTHORIZED"
|
||||
|
@ -38,7 +40,7 @@ class CodeMessageException(Exception):
|
|||
"""An exception with integer code and message string attributes."""
|
||||
|
||||
def __init__(self, code, msg):
|
||||
logging.error("%s: %s, %s", type(self).__name__, code, msg)
|
||||
logger.info("%s: %s, %s", type(self).__name__, code, msg)
|
||||
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
|
||||
self.code = code
|
||||
self.msg = msg
|
||||
|
@ -140,7 +142,8 @@ def cs_exception(exception):
|
|||
if isinstance(exception, CodeMessageException):
|
||||
return exception.error_dict()
|
||||
else:
|
||||
logging.error("Unknown exception type: %s", type(exception))
|
||||
logger.error("Unknown exception type: %s", type(exception))
|
||||
return {}
|
||||
|
||||
|
||||
def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
|
||||
|
|
|
@ -83,6 +83,8 @@ class SynapseEvent(JsonEncodedObject):
|
|||
"content",
|
||||
]
|
||||
|
||||
outlier = False
|
||||
|
||||
def __init__(self, raises=True, **kwargs):
|
||||
super(SynapseEvent, self).__init__(**kwargs)
|
||||
# if "content" in kwargs:
|
||||
|
@ -123,6 +125,7 @@ class SynapseEvent(JsonEncodedObject):
|
|||
pdu_json.pop("outlier", None)
|
||||
pdu_json.pop("replaces_state", None)
|
||||
pdu_json.pop("redacted", None)
|
||||
pdu_json.pop("prev_content", None)
|
||||
state_hash = pdu_json.pop("state_hash", None)
|
||||
if state_hash is not None:
|
||||
pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash
|
||||
|
|
|
@ -26,7 +26,7 @@ from twisted.web.server import Site
|
|||
from synapse.http.server import JsonResource, RootRedirect
|
||||
from synapse.http.content_repository import ContentRepoResource
|
||||
from synapse.http.server_key_resource import LocalKey
|
||||
from synapse.http.client import MatrixHttpClient
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.api.urls import (
|
||||
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
|
||||
SERVER_KEY_PREFIX,
|
||||
|
@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
|
|||
class SynapseHomeServer(HomeServer):
|
||||
|
||||
def build_http_client(self):
|
||||
return MatrixHttpClient(self)
|
||||
return MatrixFederationHttpClient(self)
|
||||
|
||||
def build_resource_for_client(self):
|
||||
return JsonResource()
|
||||
|
@ -116,7 +116,7 @@ class SynapseHomeServer(HomeServer):
|
|||
# extra resources to existing nodes. See self._resource_id for the key.
|
||||
resource_mappings = {}
|
||||
for (full_path, resource) in desired_tree:
|
||||
logging.info("Attaching %s to path %s", resource, full_path)
|
||||
logger.info("Attaching %s to path %s", resource, full_path)
|
||||
last_resource = self.root_resource
|
||||
for path_seg in full_path.split('/')[1:-1]:
|
||||
if not path_seg in last_resource.listNames():
|
||||
|
@ -221,12 +221,12 @@ def setup():
|
|||
|
||||
db_name = hs.get_db_name()
|
||||
|
||||
logging.info("Preparing database: %s...", db_name)
|
||||
logger.info("Preparing database: %s...", db_name)
|
||||
|
||||
with sqlite3.connect(db_name) as db_conn:
|
||||
prepare_database(db_conn)
|
||||
|
||||
logging.info("Database prepared in %s.", db_name)
|
||||
logger.info("Database prepared in %s.", db_name)
|
||||
|
||||
hs.get_db_pool()
|
||||
|
||||
|
@ -257,13 +257,16 @@ def setup():
|
|||
else:
|
||||
reactor.run()
|
||||
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
reactor.run()
|
||||
|
||||
|
||||
def main():
|
||||
with LoggingContext("main"):
|
||||
setup()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
@ -21,11 +21,12 @@ import signal
|
|||
|
||||
SYNAPSE = ["python", "-m", "synapse.app.homeserver"]
|
||||
|
||||
CONFIGFILE="homeserver.yaml"
|
||||
PIDFILE="homeserver.pid"
|
||||
CONFIGFILE = "homeserver.yaml"
|
||||
PIDFILE = "homeserver.pid"
|
||||
|
||||
GREEN = "\x1b[1;32m"
|
||||
NORMAL = "\x1b[m"
|
||||
|
||||
GREEN="\x1b[1;32m"
|
||||
NORMAL="\x1b[m"
|
||||
|
||||
def start():
|
||||
if not os.path.exists(CONFIGFILE):
|
||||
|
@ -43,12 +44,14 @@ def start():
|
|||
subprocess.check_call(args)
|
||||
print GREEN + "started" + NORMAL
|
||||
|
||||
|
||||
def stop():
|
||||
if os.path.exists(PIDFILE):
|
||||
pid = int(open(PIDFILE).read())
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
print GREEN + "stopped" + NORMAL
|
||||
|
||||
|
||||
def main():
|
||||
action = sys.argv[1] if sys.argv[1:] else "usage"
|
||||
if action == "start":
|
||||
|
@ -62,5 +65,6 @@ def main():
|
|||
sys.stderr.write("Usage: %s [start|stop|restart]\n" % (sys.argv[0],))
|
||||
sys.exit(1)
|
||||
|
||||
if __name__=='__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
|||
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
|
||||
"""Check whether the hash for this PDU matches the contents"""
|
||||
computed_hash = _compute_content_hash(event, hash_algorithm)
|
||||
logging.debug("Expecting hash: %s", encode_base64(computed_hash.digest()))
|
||||
logger.debug("Expecting hash: %s", encode_base64(computed_hash.digest()))
|
||||
if computed_hash.name not in event.hashes:
|
||||
raise SynapseError(
|
||||
400,
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
from twisted.web.http import HTTPClient
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.internet import defer, reactor
|
||||
from synapse.http.endpoint import matrix_endpoint
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
import json
|
||||
import logging
|
||||
|
@ -31,7 +31,7 @@ def fetch_server_key(server_name, ssl_context_factory):
|
|||
"""Fetch the keys for a remote server."""
|
||||
|
||||
factory = SynapseKeyClientFactory()
|
||||
endpoint = matrix_endpoint(
|
||||
endpoint = matrix_federation_endpoint(
|
||||
reactor, server_name, ssl_context_factory, timeout=30
|
||||
)
|
||||
|
||||
|
@ -48,7 +48,7 @@ def fetch_server_key(server_name, ssl_context_factory):
|
|||
|
||||
|
||||
class SynapseKeyClientError(Exception):
|
||||
"""The key wasn't retireved from the remote server."""
|
||||
"""The key wasn't retrieved from the remote server."""
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -135,7 +135,7 @@ class Keyring(object):
|
|||
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
self.store.store_server_certificate(
|
||||
yield self.store.store_server_certificate(
|
||||
server_name,
|
||||
server_name,
|
||||
time_now_ms,
|
||||
|
@ -143,7 +143,7 @@ class Keyring(object):
|
|||
)
|
||||
|
||||
for key_id, key in verify_keys.items():
|
||||
self.store.store_server_verify_key(
|
||||
yield self.store.store_server_verify_key(
|
||||
server_name, server_name, time_now_ms, key
|
||||
)
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ from .units import Transaction, Edu
|
|||
from .persistence import TransactionActions
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -319,19 +320,20 @@ class ReplicationLayer(object):
|
|||
|
||||
logger.debug("[%s] Transacition is new", transaction.transaction_id)
|
||||
|
||||
dl = []
|
||||
for pdu in pdu_list:
|
||||
dl.append(self._handle_new_pdu(transaction.origin, pdu))
|
||||
with PreserveLoggingContext():
|
||||
dl = []
|
||||
for pdu in pdu_list:
|
||||
dl.append(self._handle_new_pdu(transaction.origin, pdu))
|
||||
|
||||
if hasattr(transaction, "edus"):
|
||||
for edu in [Edu(**x) for x in transaction.edus]:
|
||||
self.received_edu(
|
||||
transaction.origin,
|
||||
edu.edu_type,
|
||||
edu.content
|
||||
)
|
||||
if hasattr(transaction, "edus"):
|
||||
for edu in [Edu(**x) for x in transaction.edus]:
|
||||
self.received_edu(
|
||||
transaction.origin,
|
||||
edu.edu_type,
|
||||
edu.content
|
||||
)
|
||||
|
||||
results = yield defer.DeferredList(dl)
|
||||
results = yield defer.DeferredList(dl)
|
||||
|
||||
ret = []
|
||||
for r in results:
|
||||
|
@ -425,7 +427,9 @@ class ReplicationLayer(object):
|
|||
time_now = self._clock.time_msec()
|
||||
defer.returnValue((200, {
|
||||
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
|
||||
"auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
|
||||
"auth_chain": [
|
||||
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
|
||||
],
|
||||
}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -436,7 +440,9 @@ class ReplicationLayer(object):
|
|||
(
|
||||
200,
|
||||
{
|
||||
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
|
||||
"auth_chain": [
|
||||
a.get_pdu_json(time_now) for a in auth_pdus
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
@ -457,7 +463,7 @@ class ReplicationLayer(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def send_join(self, destination, pdu):
|
||||
time_now = self._clock.time_msec()
|
||||
time_now = self._clock.time_msec()
|
||||
_, content = yield self.transport_layer.send_join(
|
||||
destination,
|
||||
pdu.room_id,
|
||||
|
@ -475,11 +481,17 @@ class ReplicationLayer(object):
|
|||
# FIXME: We probably want to do something with the auth_chain given
|
||||
# to us
|
||||
|
||||
# auth_chain = [
|
||||
# Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
|
||||
# ]
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
for p in content.get("auth_chain", [])
|
||||
]
|
||||
|
||||
defer.returnValue(state)
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
defer.returnValue({
|
||||
"state": state,
|
||||
"auth_chain": auth_chain,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_invite(self, destination, context, event_id, pdu):
|
||||
|
@ -498,13 +510,15 @@ class ReplicationLayer(object):
|
|||
defer.returnValue(self.event_from_pdu_json(pdu_dict))
|
||||
|
||||
@log_function
|
||||
def _get_persisted_pdu(self, origin, event_id):
|
||||
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
|
||||
""" Get a PDU from the database with given origin and id.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a `Pdu`.
|
||||
"""
|
||||
return self.handler.get_persisted_pdu(origin, event_id)
|
||||
return self.handler.get_persisted_pdu(
|
||||
origin, event_id, do_auth=do_auth
|
||||
)
|
||||
|
||||
def _transaction_from_pdus(self, pdu_list):
|
||||
"""Returns a new Transaction containing the given PDUs suitable for
|
||||
|
@ -523,7 +537,9 @@ class ReplicationLayer(object):
|
|||
@log_function
|
||||
def _handle_new_pdu(self, origin, pdu, backfilled=False):
|
||||
# We reprocess pdus when we have seen them only as outliers
|
||||
existing = yield self._get_persisted_pdu(origin, pdu.event_id)
|
||||
existing = yield self._get_persisted_pdu(
|
||||
origin, pdu.event_id, do_auth=False
|
||||
)
|
||||
|
||||
if existing and (not existing.outlier or pdu.outlier):
|
||||
logger.debug("Already seen pdu %s", pdu.event_id)
|
||||
|
@ -532,6 +548,36 @@ class ReplicationLayer(object):
|
|||
|
||||
state = None
|
||||
|
||||
# We need to make sure we have all the auth events.
|
||||
for e_id, _ in pdu.auth_events:
|
||||
exists = yield self._get_persisted_pdu(
|
||||
origin,
|
||||
e_id,
|
||||
do_auth=False
|
||||
)
|
||||
|
||||
if not exists:
|
||||
try:
|
||||
logger.debug(
|
||||
"_handle_new_pdu fetch missing auth event %s from %s",
|
||||
e_id,
|
||||
origin,
|
||||
)
|
||||
|
||||
yield self.get_pdu(
|
||||
origin,
|
||||
event_id=e_id,
|
||||
outlier=True,
|
||||
)
|
||||
|
||||
logger.debug("Processed pdu %s", e_id)
|
||||
except:
|
||||
logger.warn(
|
||||
"Failed to get auth event %s from %s",
|
||||
e_id,
|
||||
origin
|
||||
)
|
||||
|
||||
# Get missing pdus if necessary.
|
||||
if not pdu.outlier:
|
||||
# We only backfill backwards to the min depth.
|
||||
|
@ -539,16 +585,28 @@ class ReplicationLayer(object):
|
|||
pdu.room_id
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_pdu min_depth for %s: %d",
|
||||
pdu.room_id, min_depth
|
||||
)
|
||||
|
||||
if min_depth and pdu.depth > min_depth:
|
||||
for event_id, hashes in pdu.prev_events:
|
||||
exists = yield self._get_persisted_pdu(origin, event_id)
|
||||
exists = yield self._get_persisted_pdu(
|
||||
origin,
|
||||
event_id,
|
||||
do_auth=False
|
||||
)
|
||||
|
||||
if not exists:
|
||||
logger.debug("Requesting pdu %s", event_id)
|
||||
logger.debug(
|
||||
"_handle_new_pdu requesting pdu %s",
|
||||
event_id
|
||||
)
|
||||
|
||||
try:
|
||||
yield self.get_pdu(
|
||||
pdu.origin,
|
||||
origin,
|
||||
event_id=event_id,
|
||||
)
|
||||
logger.debug("Processed pdu %s", event_id)
|
||||
|
@ -558,6 +616,10 @@ class ReplicationLayer(object):
|
|||
else:
|
||||
# We need to get the state at this event, since we have reached
|
||||
# a backward extremity edge.
|
||||
logger.debug(
|
||||
"_handle_new_pdu getting state for %s",
|
||||
pdu.room_id
|
||||
)
|
||||
state = yield self.get_state_for_context(
|
||||
origin, pdu.room_id, pdu.event_id,
|
||||
)
|
||||
|
@ -649,7 +711,8 @@ class _TransactionQueue(object):
|
|||
(pdu, deferred, order)
|
||||
)
|
||||
|
||||
self._attempt_new_transaction(destination)
|
||||
with PreserveLoggingContext():
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
deferreds.append(deferred)
|
||||
|
||||
|
@ -669,7 +732,9 @@ class _TransactionQueue(object):
|
|||
deferred.errback(failure)
|
||||
else:
|
||||
logger.exception("Failed to send edu", failure)
|
||||
self._attempt_new_transaction(destination).addErrback(eb)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self._attempt_new_transaction(destination).addErrback(eb)
|
||||
|
||||
return deferred
|
||||
|
||||
|
|
|
@ -25,7 +25,6 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class Edu(JsonEncodedObject):
|
||||
""" An Edu represents a piece of data sent from one homeserver to another.
|
||||
|
||||
|
|
|
@ -78,7 +78,7 @@ class BaseHandler(object):
|
|||
|
||||
if not suppress_auth:
|
||||
logger.debug("Authing...")
|
||||
self.auth.check(event, raises=True)
|
||||
self.auth.check(event, auth_events=event.old_state_events)
|
||||
logger.debug("Authed")
|
||||
else:
|
||||
logger.debug("Suppressed auth.")
|
||||
|
@ -112,7 +112,7 @@ class BaseHandler(object):
|
|||
|
||||
event.destinations = list(destinations)
|
||||
|
||||
self.notifier.on_new_room_event(event, extra_users=extra_users)
|
||||
yield self.notifier.on_new_room_event(event, extra_users=extra_users)
|
||||
|
||||
federation_handler = self.hs.get_handlers().federation_handler
|
||||
yield federation_handler.handle_new_event(event, snapshot)
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
from twisted.internet import defer
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import SynapseError, Codes, CodeMessageException
|
||||
from synapse.api.events.room import RoomAliasesEvent
|
||||
|
||||
import logging
|
||||
|
@ -84,22 +84,32 @@ class DirectoryHandler(BaseHandler):
|
|||
room_id = result.room_id
|
||||
servers = result.servers
|
||||
else:
|
||||
result = yield self.federation.make_query(
|
||||
destination=room_alias.domain,
|
||||
query_type="directory",
|
||||
args={
|
||||
"room_alias": room_alias.to_string(),
|
||||
},
|
||||
retry_on_dns_fail=False,
|
||||
)
|
||||
try:
|
||||
result = yield self.federation.make_query(
|
||||
destination=room_alias.domain,
|
||||
query_type="directory",
|
||||
args={
|
||||
"room_alias": room_alias.to_string(),
|
||||
},
|
||||
retry_on_dns_fail=False,
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
logging.warn("Error retrieving alias")
|
||||
if e.code == 404:
|
||||
result = None
|
||||
else:
|
||||
raise
|
||||
|
||||
if result and "room_id" in result and "servers" in result:
|
||||
room_id = result["room_id"]
|
||||
servers = result["servers"]
|
||||
|
||||
if not room_id:
|
||||
defer.returnValue({})
|
||||
return
|
||||
raise SynapseError(
|
||||
404,
|
||||
"Room alias %r not found" % (room_alias.to_string(),),
|
||||
Codes.NOT_FOUND
|
||||
)
|
||||
|
||||
extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
servers = list(set(extra_servers) | set(servers))
|
||||
|
@ -128,8 +138,11 @@ class DirectoryHandler(BaseHandler):
|
|||
"servers": result.servers,
|
||||
})
|
||||
else:
|
||||
raise SynapseError(404, "Room alias \"%s\" not found", room_alias)
|
||||
|
||||
raise SynapseError(
|
||||
404,
|
||||
"Room alias %r not found" % (room_alias.to_string(),),
|
||||
Codes.NOT_FOUND
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_room_alias_update_event(self, user_id, room_id):
|
||||
|
|
|
@ -56,7 +56,7 @@ class EventStreamHandler(BaseHandler):
|
|||
self.clock.cancel_call_later(
|
||||
self._stop_timer_per_user.pop(auth_user))
|
||||
else:
|
||||
self.distributor.fire(
|
||||
yield self.distributor.fire(
|
||||
"started_user_eventstream", auth_user
|
||||
)
|
||||
self._streams_per_user[auth_user] += 1
|
||||
|
@ -65,8 +65,10 @@ class EventStreamHandler(BaseHandler):
|
|||
pagin_config.from_token = None
|
||||
|
||||
rm_handler = self.hs.get_handlers().room_member_handler
|
||||
logger.debug("BETA")
|
||||
room_ids = yield rm_handler.get_rooms_for_user(auth_user)
|
||||
|
||||
logger.debug("ALPHA")
|
||||
with PreserveLoggingContext():
|
||||
events, tokens = yield self.notifier.get_events_for(
|
||||
auth_user, room_ids, pagin_config, timeout
|
||||
|
@ -93,7 +95,7 @@ class EventStreamHandler(BaseHandler):
|
|||
logger.debug(
|
||||
"_later stopped_user_eventstream %s", auth_user
|
||||
)
|
||||
self.distributor.fire(
|
||||
yield self.distributor.fire(
|
||||
"stopped_user_eventstream", auth_user
|
||||
)
|
||||
del self._stop_timer_per_user[auth_user]
|
||||
|
|
|
@ -24,7 +24,8 @@ from synapse.api.constants import Membership
|
|||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.crypto.event_signing import (
|
||||
compute_event_signature, check_event_content_hash
|
||||
compute_event_signature, check_event_content_hash,
|
||||
add_hashes_and_signatures,
|
||||
)
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
|
@ -122,7 +123,8 @@ class FederationHandler(BaseHandler):
|
|||
event.origin, redacted_pdu_json
|
||||
)
|
||||
except SynapseError as e:
|
||||
logger.warn("Signature check failed for %s redacted to %s",
|
||||
logger.warn(
|
||||
"Signature check failed for %s redacted to %s",
|
||||
encode_canonical_json(pdu.get_pdu_json()),
|
||||
encode_canonical_json(redacted_pdu_json),
|
||||
)
|
||||
|
@ -140,15 +142,27 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
event = redacted_event
|
||||
|
||||
is_new_state = yield self.state_handler.annotate_event_with_state(
|
||||
event,
|
||||
old_state=state
|
||||
)
|
||||
|
||||
logger.debug("Event: %s", event)
|
||||
|
||||
# FIXME (erikj): Awful hack to make the case where we are not currently
|
||||
# in the room work
|
||||
current_state = None
|
||||
if state:
|
||||
is_in_room = yield self.auth.check_host_in_room(
|
||||
event.room_id,
|
||||
self.server_name
|
||||
)
|
||||
if not is_in_room:
|
||||
logger.debug("Got event for room we're not in.")
|
||||
current_state = state
|
||||
|
||||
try:
|
||||
self.auth.check(event, raises=True)
|
||||
yield self._handle_new_event(
|
||||
event,
|
||||
state=state,
|
||||
backfilled=backfilled,
|
||||
current_state=current_state,
|
||||
)
|
||||
except AuthError as e:
|
||||
raise FederationError(
|
||||
"ERROR",
|
||||
|
@ -157,43 +171,14 @@ class FederationHandler(BaseHandler):
|
|||
affected=event.event_id,
|
||||
)
|
||||
|
||||
is_new_state = is_new_state and not backfilled
|
||||
|
||||
# TODO: Implement something in federation that allows us to
|
||||
# respond to PDU.
|
||||
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
backfilled,
|
||||
is_new_state=is_new_state
|
||||
)
|
||||
|
||||
room = yield self.store.get_room(event.room_id)
|
||||
|
||||
if not room:
|
||||
# Huh, let's try and get the current state
|
||||
try:
|
||||
yield self.replication_layer.get_state_for_context(
|
||||
event.origin, event.room_id, event.event_id,
|
||||
)
|
||||
|
||||
hosts = yield self.store.get_joined_hosts_for_room(
|
||||
event.room_id
|
||||
)
|
||||
if self.hs.hostname in hosts:
|
||||
try:
|
||||
yield self.store.store_room(
|
||||
room_id=event.room_id,
|
||||
room_creator_user_id="",
|
||||
is_public=False,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to get current state for room %s",
|
||||
event.room_id
|
||||
)
|
||||
yield self.store.store_room(
|
||||
room_id=event.room_id,
|
||||
room_creator_user_id="",
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
if not backfilled:
|
||||
extra_users = []
|
||||
|
@ -209,7 +194,7 @@ class FederationHandler(BaseHandler):
|
|||
if event.type == RoomMemberEvent.TYPE:
|
||||
if event.membership == Membership.JOIN:
|
||||
user = self.hs.parse_userid(event.state_key)
|
||||
self.distributor.fire(
|
||||
yield self.distributor.fire(
|
||||
"user_joined_room", user=user, room_id=event.room_id
|
||||
)
|
||||
|
||||
|
@ -254,6 +239,8 @@ class FederationHandler(BaseHandler):
|
|||
pdu=event
|
||||
)
|
||||
|
||||
|
||||
|
||||
defer.returnValue(pdu)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -275,6 +262,8 @@ class FederationHandler(BaseHandler):
|
|||
We suspend processing of any received events from this room until we
|
||||
have finished processing the join.
|
||||
"""
|
||||
logger.debug("Joining %s to %s", joinee, room_id)
|
||||
|
||||
pdu = yield self.replication_layer.make_join(
|
||||
target_host,
|
||||
room_id,
|
||||
|
@ -297,19 +286,28 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
try:
|
||||
event.event_id = self.event_factory.create_event_id()
|
||||
event.origin = self.hs.hostname
|
||||
event.content = content
|
||||
|
||||
state = yield self.replication_layer.send_join(
|
||||
if not hasattr(event, "signatures"):
|
||||
event.signatures = {}
|
||||
|
||||
add_hashes_and_signatures(
|
||||
event,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0],
|
||||
)
|
||||
|
||||
ret = yield self.replication_layer.send_join(
|
||||
target_host,
|
||||
event
|
||||
)
|
||||
|
||||
logger.debug("do_invite_join state: %s", state)
|
||||
state = ret["state"]
|
||||
auth_chain = ret["auth_chain"]
|
||||
|
||||
yield self.state_handler.annotate_event_with_state(
|
||||
event,
|
||||
old_state=state
|
||||
)
|
||||
logger.debug("do_invite_join auth_chain: %s", auth_chain)
|
||||
logger.debug("do_invite_join state: %s", state)
|
||||
|
||||
logger.debug("do_invite_join event: %s", event)
|
||||
|
||||
|
@ -323,34 +321,41 @@ class FederationHandler(BaseHandler):
|
|||
# FIXME
|
||||
pass
|
||||
|
||||
for e in auth_chain:
|
||||
e.outlier = True
|
||||
yield self._handle_new_event(e)
|
||||
yield self.notifier.on_new_room_event(
|
||||
e, extra_users=[joinee]
|
||||
)
|
||||
|
||||
for e in state:
|
||||
# FIXME: Auth these.
|
||||
e.outlier = True
|
||||
|
||||
yield self.state_handler.annotate_event_with_state(
|
||||
e,
|
||||
yield self._handle_new_event(e)
|
||||
yield self.notifier.on_new_room_event(
|
||||
e, extra_users=[joinee]
|
||||
)
|
||||
|
||||
yield self.store.persist_event(
|
||||
e,
|
||||
backfilled=False,
|
||||
is_new_state=True
|
||||
)
|
||||
|
||||
yield self.store.persist_event(
|
||||
yield self._handle_new_event(
|
||||
event,
|
||||
backfilled=False,
|
||||
is_new_state=True
|
||||
state=state,
|
||||
current_state=state
|
||||
)
|
||||
|
||||
yield self.notifier.on_new_room_event(
|
||||
event, extra_users=[joinee]
|
||||
)
|
||||
|
||||
logger.debug("Finished joining %s to %s", joinee, room_id)
|
||||
finally:
|
||||
room_queue = self.room_queues[room_id]
|
||||
del self.room_queues[room_id]
|
||||
|
||||
for p in room_queue:
|
||||
try:
|
||||
yield self.on_receive_pdu(p, backfilled=False)
|
||||
self.on_receive_pdu(p, backfilled=False)
|
||||
except:
|
||||
pass
|
||||
logger.exception("Couldn't handle pdu")
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
|
@ -374,7 +379,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
yield self.state_handler.annotate_event_with_state(event)
|
||||
yield self.auth.add_auth_events(event)
|
||||
self.auth.check(event, raises=True)
|
||||
self.auth.check(event, auth_events=event.old_state_events)
|
||||
|
||||
pdu = event
|
||||
|
||||
|
@ -390,16 +395,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
event.outlier = False
|
||||
|
||||
is_new_state = yield self.state_handler.annotate_event_with_state(event)
|
||||
self.auth.check(event, raises=True)
|
||||
|
||||
# FIXME (erikj): All this is duplicated above :(
|
||||
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
backfilled=False,
|
||||
is_new_state=is_new_state
|
||||
)
|
||||
yield self._handle_new_event(event)
|
||||
|
||||
extra_users = []
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
|
@ -412,9 +408,9 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
if event.membership == Membership.JOIN:
|
||||
if event.content["membership"] == Membership.JOIN:
|
||||
user = self.hs.parse_userid(event.state_key)
|
||||
self.distributor.fire(
|
||||
yield self.distributor.fire(
|
||||
"user_joined_room", user=user, room_id=event.room_id
|
||||
)
|
||||
|
||||
|
@ -527,7 +523,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_persisted_pdu(self, origin, event_id):
|
||||
def get_persisted_pdu(self, origin, event_id, do_auth=True):
|
||||
""" Get a PDU from the database with given origin and id.
|
||||
|
||||
Returns:
|
||||
|
@ -539,12 +535,13 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
if event:
|
||||
in_room = yield self.auth.check_host_in_room(
|
||||
event.room_id,
|
||||
origin
|
||||
)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
if do_auth:
|
||||
in_room = yield self.auth.check_host_in_room(
|
||||
event.room_id,
|
||||
origin
|
||||
)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
defer.returnValue(event)
|
||||
else:
|
||||
|
@ -562,3 +559,65 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
while waiters:
|
||||
waiters.pop().callback(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_new_event(self, event, state=None, backfilled=False,
|
||||
current_state=None):
|
||||
if state:
|
||||
for s in state:
|
||||
yield self._handle_new_event(s)
|
||||
|
||||
is_new_state = yield self.state_handler.annotate_event_with_state(
|
||||
event,
|
||||
old_state=state
|
||||
)
|
||||
|
||||
if event.old_state_events:
|
||||
known_ids = set(
|
||||
[s.event_id for s in event.old_state_events.values()]
|
||||
)
|
||||
for e_id, _ in event.auth_events:
|
||||
if e_id not in known_ids:
|
||||
e = yield self.store.get_event(
|
||||
e_id,
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if not e:
|
||||
# TODO: Do some conflict res to make sure that we're
|
||||
# not the ones who are wrong.
|
||||
logger.info(
|
||||
"Rejecting %s as %s not in %s",
|
||||
event.event_id, e_id, known_ids,
|
||||
)
|
||||
raise AuthError(403, "Auth events are stale")
|
||||
|
||||
auth_events = event.old_state_events
|
||||
else:
|
||||
# We need to get the auth events from somewhere.
|
||||
|
||||
# TODO: Don't just hit the DBs?
|
||||
|
||||
auth_events = {}
|
||||
for e_id, _ in event.auth_events:
|
||||
e = yield self.store.get_event(
|
||||
e_id,
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if not e:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Can't find auth event %s." % (e_id, )
|
||||
)
|
||||
|
||||
auth_events[(e.type, e.state_key)] = e
|
||||
|
||||
self.auth.check(event, auth_events=auth_events)
|
||||
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
backfilled=backfilled,
|
||||
is_new_state=(is_new_state and not backfilled),
|
||||
current_state=current_state,
|
||||
)
|
||||
|
|
|
@ -17,13 +17,12 @@ from twisted.internet import defer
|
|||
|
||||
from ._base import BaseHandler
|
||||
from synapse.api.errors import LoginError, Codes
|
||||
from synapse.http.client import IdentityServerHttpClient
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.util.emailutils import EmailException
|
||||
import synapse.util.emailutils as emailutils
|
||||
|
||||
import bcrypt
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -97,10 +96,16 @@ class LoginHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _query_email(self, email):
|
||||
httpCli = IdentityServerHttpClient(self.hs)
|
||||
httpCli = SimpleHttpClient(self.hs)
|
||||
data = yield httpCli.get_json(
|
||||
'matrix.org:8090', # TODO FIXME This should be configurable.
|
||||
"/_matrix/identity/api/v1/lookup?medium=email&address=" +
|
||||
"%s" % urllib.quote(email)
|
||||
# TODO FIXME This should be configurable.
|
||||
# XXX: ID servers need to use HTTPS
|
||||
"http://%s%s" % (
|
||||
"matrix.org:8090", "/_matrix/identity/api/v1/lookup"
|
||||
),
|
||||
{
|
||||
'medium': 'email',
|
||||
'address': email
|
||||
}
|
||||
)
|
||||
defer.returnValue(data)
|
||||
|
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
|||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import RoomError
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
|
@ -86,9 +87,10 @@ class MessageHandler(BaseHandler):
|
|||
event, snapshot, suppress_auth=suppress_auth
|
||||
)
|
||||
|
||||
self.hs.get_handlers().presence_handler.bump_presence_active_time(
|
||||
user
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
self.hs.get_handlers().presence_handler.bump_presence_active_time(
|
||||
user
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
|
||||
|
@ -241,7 +243,7 @@ class MessageHandler(BaseHandler):
|
|||
public_room_ids = [r["room_id"] for r in public_rooms]
|
||||
|
||||
limit = pagin_config.limit
|
||||
if not limit:
|
||||
if limit is None:
|
||||
limit = 10
|
||||
|
||||
for event in room_list:
|
||||
|
@ -296,7 +298,7 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def room_initial_sync(self, user_id, room_id, pagin_config=None,
|
||||
feedback=False):
|
||||
feedback=False):
|
||||
yield self.auth.check_joined_room(room_id, user_id)
|
||||
|
||||
# TODO(paul): I wish I was called with user objects not user_id
|
||||
|
@ -304,7 +306,7 @@ class MessageHandler(BaseHandler):
|
|||
auth_user = self.hs.parse_userid(user_id)
|
||||
|
||||
# TODO: These concurrently
|
||||
state_tuples = yield self.store.get_current_state(room_id)
|
||||
state_tuples = yield self.state_handler.get_current_state(room_id)
|
||||
state = [self.hs.serialize_event(x) for x in state_tuples]
|
||||
|
||||
member_event = (yield self.store.get_room_member(
|
||||
|
@ -340,8 +342,8 @@ class MessageHandler(BaseHandler):
|
|||
)
|
||||
presence.append(member_presence)
|
||||
except Exception:
|
||||
logger.exception("Failed to get member presence of %r",
|
||||
m.user_id
|
||||
logger.exception(
|
||||
"Failed to get member presence of %r", m.user_id
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
|
|
|
@ -19,6 +19,7 @@ from synapse.api.errors import SynapseError, AuthError
|
|||
from synapse.api.constants import PresenceState
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
@ -142,7 +143,7 @@ class PresenceHandler(BaseHandler):
|
|||
return UserPresenceCache()
|
||||
|
||||
def registered_user(self, user):
|
||||
self.store.create_presence(user.localpart)
|
||||
return self.store.create_presence(user.localpart)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_presence_visible(self, observer_user, observed_user):
|
||||
|
@ -241,14 +242,12 @@ class PresenceHandler(BaseHandler):
|
|||
was_level = self.STATE_LEVELS[statuscache.get_state()["presence"]]
|
||||
now_level = self.STATE_LEVELS[state["presence"]]
|
||||
|
||||
yield defer.DeferredList([
|
||||
self.store.set_presence_state(
|
||||
target_user.localpart, state_to_store
|
||||
),
|
||||
self.distributor.fire(
|
||||
"collect_presencelike_data", target_user, state
|
||||
),
|
||||
])
|
||||
yield self.store.set_presence_state(
|
||||
target_user.localpart, state_to_store
|
||||
)
|
||||
yield self.distributor.fire(
|
||||
"collect_presencelike_data", target_user, state
|
||||
)
|
||||
|
||||
if now_level > was_level:
|
||||
state["last_active"] = self.clock.time_msec()
|
||||
|
@ -256,14 +255,15 @@ class PresenceHandler(BaseHandler):
|
|||
now_online = state["presence"] != PresenceState.OFFLINE
|
||||
was_polling = target_user in self._user_cachemap
|
||||
|
||||
if now_online and not was_polling:
|
||||
self.start_polling_presence(target_user, state=state)
|
||||
elif not now_online and was_polling:
|
||||
self.stop_polling_presence(target_user)
|
||||
with PreserveLoggingContext():
|
||||
if now_online and not was_polling:
|
||||
self.start_polling_presence(target_user, state=state)
|
||||
elif not now_online and was_polling:
|
||||
self.stop_polling_presence(target_user)
|
||||
|
||||
# TODO(paul): perform a presence push as part of start/stop poll so
|
||||
# we don't have to do this all the time
|
||||
self.changed_presencelike_data(target_user, state)
|
||||
# TODO(paul): perform a presence push as part of start/stop poll so
|
||||
# we don't have to do this all the time
|
||||
self.changed_presencelike_data(target_user, state)
|
||||
|
||||
def bump_presence_active_time(self, user, now=None):
|
||||
if now is None:
|
||||
|
@ -277,7 +277,7 @@ class PresenceHandler(BaseHandler):
|
|||
self._user_cachemap_latest_serial += 1
|
||||
statuscache.update(state, serial=self._user_cachemap_latest_serial)
|
||||
|
||||
self.push_presence(user, statuscache=statuscache)
|
||||
return self.push_presence(user, statuscache=statuscache)
|
||||
|
||||
@log_function
|
||||
def started_user_eventstream(self, user):
|
||||
|
@ -381,8 +381,10 @@ class PresenceHandler(BaseHandler):
|
|||
yield self.store.set_presence_list_accepted(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
self.start_polling_presence(observer_user, target_user=observed_user)
|
||||
with PreserveLoggingContext():
|
||||
self.start_polling_presence(
|
||||
observer_user, target_user=observed_user
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deny_presence(self, observed_user, observer_user):
|
||||
|
@ -401,7 +403,10 @@ class PresenceHandler(BaseHandler):
|
|||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
self.stop_polling_presence(observer_user, target_user=observed_user)
|
||||
with PreserveLoggingContext():
|
||||
self.stop_polling_presence(
|
||||
observer_user, target_user=observed_user
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_presence_list(self, observer_user, accepted=None):
|
||||
|
@ -710,7 +715,8 @@ class PresenceHandler(BaseHandler):
|
|||
if not self._remote_sendmap[user]:
|
||||
del self._remote_sendmap[user]
|
||||
|
||||
yield defer.DeferredList(deferreds)
|
||||
with PreserveLoggingContext():
|
||||
yield defer.DeferredList(deferreds)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def push_update_to_local_and_remote(self, observed_user, statuscache,
|
||||
|
|
|
@ -17,6 +17,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
@ -46,7 +47,7 @@ class ProfileHandler(BaseHandler):
|
|||
)
|
||||
|
||||
def registered_user(self, user):
|
||||
self.store.create_profile(user.localpart)
|
||||
return self.store.create_profile(user.localpart)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_displayname(self, target_user):
|
||||
|
@ -152,13 +153,14 @@ class ProfileHandler(BaseHandler):
|
|||
if not user.is_mine:
|
||||
defer.returnValue(None)
|
||||
|
||||
(displayname, avatar_url) = yield defer.gatherResults(
|
||||
[
|
||||
self.store.get_profile_displayname(user.localpart),
|
||||
self.store.get_profile_avatar_url(user.localpart),
|
||||
],
|
||||
consumeErrors=True
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
(displayname, avatar_url) = yield defer.gatherResults(
|
||||
[
|
||||
self.store.get_profile_displayname(user.localpart),
|
||||
self.store.get_profile_avatar_url(user.localpart),
|
||||
],
|
||||
consumeErrors=True
|
||||
)
|
||||
|
||||
state["displayname"] = displayname
|
||||
state["avatar_url"] = avatar_url
|
||||
|
|
|
@ -22,7 +22,7 @@ from synapse.api.errors import (
|
|||
)
|
||||
from ._base import BaseHandler
|
||||
import synapse.util.stringutils as stringutils
|
||||
from synapse.http.client import IdentityServerHttpClient
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.http.client import CaptchaServerHttpClient
|
||||
|
||||
import base64
|
||||
|
@ -69,7 +69,7 @@ class RegistrationHandler(BaseHandler):
|
|||
password_hash=password_hash
|
||||
)
|
||||
|
||||
self.distributor.fire("registered_user", user)
|
||||
yield self.distributor.fire("registered_user", user)
|
||||
else:
|
||||
# autogen a random user ID
|
||||
attempts = 0
|
||||
|
@ -133,7 +133,7 @@ class RegistrationHandler(BaseHandler):
|
|||
|
||||
if not threepid:
|
||||
raise RegistrationError(400, "Couldn't validate 3pid")
|
||||
logger.info("got threepid medium %s address %s",
|
||||
logger.info("got threepid with medium '%s' and address '%s'",
|
||||
threepid['medium'], threepid['address'])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -159,7 +159,7 @@ class RegistrationHandler(BaseHandler):
|
|||
def _threepid_from_creds(self, creds):
|
||||
# TODO: get this from the homeserver rather than creating a new one for
|
||||
# each request
|
||||
httpCli = IdentityServerHttpClient(self.hs)
|
||||
httpCli = SimpleHttpClient(self.hs)
|
||||
# XXX: make this configurable!
|
||||
trustedIdServers = ['matrix.org:8090']
|
||||
if not creds['idServer'] in trustedIdServers:
|
||||
|
@ -167,8 +167,11 @@ class RegistrationHandler(BaseHandler):
|
|||
'credentials', creds['idServer'])
|
||||
defer.returnValue(None)
|
||||
data = yield httpCli.get_json(
|
||||
creds['idServer'],
|
||||
"/_matrix/identity/api/v1/3pid/getValidated3pid",
|
||||
# XXX: This should be HTTPS
|
||||
"http://%s%s" % (
|
||||
creds['idServer'],
|
||||
"/_matrix/identity/api/v1/3pid/getValidated3pid"
|
||||
),
|
||||
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
|
||||
)
|
||||
|
||||
|
@ -178,16 +181,21 @@ class RegistrationHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _bind_threepid(self, creds, mxid):
|
||||
httpCli = IdentityServerHttpClient(self.hs)
|
||||
yield
|
||||
logger.debug("binding threepid")
|
||||
httpCli = SimpleHttpClient(self.hs)
|
||||
data = yield httpCli.post_urlencoded_get_json(
|
||||
creds['idServer'],
|
||||
"/_matrix/identity/api/v1/3pid/bind",
|
||||
# XXX: Change when ID servers are all HTTPS
|
||||
"http://%s%s" % (
|
||||
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
|
||||
),
|
||||
{
|
||||
'sid': creds['sid'],
|
||||
'clientSecret': creds['clientSecret'],
|
||||
'mxid': mxid,
|
||||
}
|
||||
)
|
||||
logger.debug("bound threepid")
|
||||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -215,10 +223,7 @@ class RegistrationHandler(BaseHandler):
|
|||
# each request
|
||||
client = CaptchaServerHttpClient(self.hs)
|
||||
data = yield client.post_urlencoded_get_raw(
|
||||
"www.google.com:80",
|
||||
"/recaptcha/api/verify",
|
||||
# twisted dislikes google's response, no content length.
|
||||
accept_partial=True,
|
||||
"http://www.google.com:80/recaptcha/api/verify",
|
||||
args={
|
||||
'privatekey': private_key,
|
||||
'remoteip': ip_addr,
|
||||
|
|
|
@ -178,7 +178,9 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
if room_alias:
|
||||
result["room_alias"] = room_alias.to_string()
|
||||
directory_handler.send_room_alias_update_event(user_id, room_id)
|
||||
yield directory_handler.send_room_alias_update_event(
|
||||
user_id, room_id
|
||||
)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
|
@ -211,7 +213,6 @@ class RoomCreationHandler(BaseHandler):
|
|||
**event_keys
|
||||
)
|
||||
|
||||
|
||||
power_levels_event = self.event_factory.create_event(
|
||||
etype=RoomPowerLevelsEvent.TYPE,
|
||||
content={
|
||||
|
@ -480,7 +481,7 @@ class RoomMemberHandler(BaseHandler):
|
|||
)
|
||||
|
||||
user = self.hs.parse_userid(event.user_id)
|
||||
self.distributor.fire(
|
||||
yield self.distributor.fire(
|
||||
"user_joined_room", user=user, room_id=room_id
|
||||
)
|
||||
|
||||
|
|
|
@ -15,308 +15,45 @@
|
|||
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.web.client import (
|
||||
_AgentBase, _URI, readBody, FileBodyProducer, PartialDownloadError
|
||||
Agent, readBody, FileBodyProducer, PartialDownloadError
|
||||
)
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
from synapse.http.endpoint import matrix_endpoint
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
from synapse.api.errors import CodeMessageException, SynapseError
|
||||
|
||||
from syutil.crypto.jsonsign import sign_json
|
||||
|
||||
from StringIO import StringIO
|
||||
|
||||
import json
|
||||
import logging
|
||||
import urllib
|
||||
import urlparse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MatrixHttpAgent(_AgentBase):
|
||||
|
||||
def __init__(self, reactor, pool=None):
|
||||
_AgentBase.__init__(self, reactor, pool)
|
||||
|
||||
def request(self, destination, endpoint, method, path, params, query,
|
||||
headers, body_producer):
|
||||
|
||||
host = b""
|
||||
port = 0
|
||||
fragment = b""
|
||||
|
||||
parsed_URI = _URI(b"http", destination, host, port, path, params,
|
||||
query, fragment)
|
||||
|
||||
# Set the connection pool key to be the destination.
|
||||
key = destination
|
||||
|
||||
return self._requestWithEndpoint(key, endpoint, method, parsed_URI,
|
||||
headers, body_producer,
|
||||
parsed_URI.originForm)
|
||||
|
||||
|
||||
class BaseHttpClient(object):
|
||||
"""Base class for HTTP clients using twisted.
|
||||
class SimpleHttpClient(object):
|
||||
"""
|
||||
A simple, no-frills HTTP client with methods that wrap up common ways of
|
||||
using HTTP in Matrix
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.agent = MatrixHttpAgent(reactor)
|
||||
self.hs = hs
|
||||
# The default context factory in Twisted 14.0.0 (which we require) is
|
||||
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
||||
# 'like a browser'
|
||||
self.agent = Agent(reactor)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_request(self, destination, method, path_bytes,
|
||||
body_callback, headers_dict={}, param_bytes=b"",
|
||||
query_bytes=b"", retry_on_dns_fail=True):
|
||||
""" Creates and sends a request to the given url
|
||||
"""
|
||||
headers_dict[b"User-Agent"] = [b"Synapse"]
|
||||
headers_dict[b"Host"] = [destination]
|
||||
|
||||
url_bytes = urlparse.urlunparse(
|
||||
("", "", path_bytes, param_bytes, query_bytes, "",)
|
||||
)
|
||||
|
||||
logger.debug("Sending request to %s: %s %s",
|
||||
destination, method, url_bytes)
|
||||
|
||||
logger.debug(
|
||||
"Types: %s",
|
||||
[
|
||||
type(destination), type(method), type(path_bytes),
|
||||
type(param_bytes),
|
||||
type(query_bytes)
|
||||
]
|
||||
)
|
||||
|
||||
retries_left = 5
|
||||
|
||||
endpoint = self._getEndpoint(reactor, destination)
|
||||
|
||||
while True:
|
||||
|
||||
producer = None
|
||||
if body_callback:
|
||||
producer = body_callback(method, url_bytes, headers_dict)
|
||||
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
response = yield self.agent.request(
|
||||
destination,
|
||||
endpoint,
|
||||
method,
|
||||
path_bytes,
|
||||
param_bytes,
|
||||
query_bytes,
|
||||
Headers(headers_dict),
|
||||
producer
|
||||
)
|
||||
|
||||
logger.debug("Got response to %s", method)
|
||||
break
|
||||
except Exception as e:
|
||||
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
|
||||
logger.warn("DNS Lookup failed to %s with %s", destination,
|
||||
e)
|
||||
raise SynapseError(400, "Domain specified not found.")
|
||||
|
||||
logger.exception("Got error in _create_request")
|
||||
_print_ex(e)
|
||||
|
||||
if retries_left:
|
||||
yield sleep(2 ** (5 - retries_left))
|
||||
retries_left -= 1
|
||||
else:
|
||||
raise
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
pass
|
||||
else:
|
||||
# :'(
|
||||
# Update transactions table?
|
||||
logger.error(
|
||||
"Got response %d %s", response.code, response.phrase
|
||||
)
|
||||
raise CodeMessageException(
|
||||
response.code, response.phrase
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
||||
|
||||
class MatrixHttpClient(BaseHttpClient):
|
||||
""" Wrapper around the twisted HTTP client api. Implements
|
||||
|
||||
Attributes:
|
||||
agent (twisted.web.client.Agent): The twisted Agent used to send the
|
||||
requests.
|
||||
"""
|
||||
|
||||
RETRY_DNS_LOOKUP_FAILURES = "__retry_dns"
|
||||
|
||||
def __init__(self, hs):
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
BaseHttpClient.__init__(self, hs)
|
||||
|
||||
def sign_request(self, destination, method, url_bytes, headers_dict,
|
||||
content=None):
|
||||
request = {
|
||||
"method": method,
|
||||
"uri": url_bytes,
|
||||
"origin": self.server_name,
|
||||
"destination": destination,
|
||||
}
|
||||
|
||||
if content is not None:
|
||||
request["content"] = content
|
||||
|
||||
request = sign_json(request, self.server_name, self.signing_key)
|
||||
|
||||
auth_headers = []
|
||||
|
||||
for key, sig in request["signatures"][self.server_name].items():
|
||||
auth_headers.append(bytes(
|
||||
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
||||
self.server_name, key, sig,
|
||||
)
|
||||
))
|
||||
|
||||
headers_dict[b"Authorization"] = auth_headers
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def put_json(self, destination, path, data={}, json_data_callback=None):
|
||||
""" Sends the specifed json data using PUT
|
||||
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request
|
||||
to.
|
||||
path (str): The HTTP path.
|
||||
data (dict): A dict containing the data that will be used as
|
||||
the request body. This will be encoded as JSON.
|
||||
json_data_callback (callable): A callable returning the dict to
|
||||
use as the request body.
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body. On a 4xx or 5xx error response a
|
||||
CodeMessageException is raised.
|
||||
"""
|
||||
|
||||
if not json_data_callback:
|
||||
def json_data_callback():
|
||||
return data
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
json_data = json_data_callback()
|
||||
self.sign_request(
|
||||
destination, method, url_bytes, headers_dict, json_data
|
||||
)
|
||||
producer = _JsonProducer(json_data)
|
||||
return producer
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
"PUT",
|
||||
path.encode("ascii"),
|
||||
body_callback=body_callback,
|
||||
headers_dict={"Content-Type": ["application/json"]},
|
||||
)
|
||||
|
||||
logger.debug("Getting resp body")
|
||||
body = yield readBody(response)
|
||||
logger.debug("Got resp body")
|
||||
|
||||
defer.returnValue((response.code, body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
|
||||
""" Get's some json from the given host homeserver and path
|
||||
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request
|
||||
to.
|
||||
path (str): The HTTP path.
|
||||
args (dict): A dictionary used to create query strings, defaults to
|
||||
None.
|
||||
**Note**: The value of each key is assumed to be an iterable
|
||||
and *not* a string.
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* HTTP response.
|
||||
|
||||
The result of the deferred is a tuple of `(code, response)`,
|
||||
where `response` is a dict representing the decoded JSON body.
|
||||
"""
|
||||
logger.debug("get_json args: %s", args)
|
||||
|
||||
encoded_args = {}
|
||||
for k, vs in args.items():
|
||||
if isinstance(vs, basestring):
|
||||
vs = [vs]
|
||||
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
||||
|
||||
query_bytes = urllib.urlencode(encoded_args, True)
|
||||
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
self.sign_request(destination, method, url_bytes, headers_dict)
|
||||
return None
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
"GET",
|
||||
path.encode("ascii"),
|
||||
query_bytes=query_bytes,
|
||||
body_callback=body_callback,
|
||||
retry_on_dns_fail=retry_on_dns_fail
|
||||
)
|
||||
|
||||
body = yield readBody(response)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
def _getEndpoint(self, reactor, destination):
|
||||
return matrix_endpoint(
|
||||
reactor, destination, timeout=10,
|
||||
ssl_context_factory=self.hs.tls_context_factory
|
||||
)
|
||||
|
||||
|
||||
class IdentityServerHttpClient(BaseHttpClient):
|
||||
"""Separate HTTP client for talking to the Identity servers since they
|
||||
don't use SRV records and talk x-www-form-urlencoded rather than JSON.
|
||||
"""
|
||||
def _getEndpoint(self, reactor, destination):
|
||||
#TODO: This should be talking TLS
|
||||
return matrix_endpoint(reactor, destination, timeout=10)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_urlencoded_get_json(self, destination, path, args={}):
|
||||
def post_urlencoded_get_json(self, uri, args={}):
|
||||
logger.debug("post_urlencoded_get_json args: %s", args)
|
||||
query_bytes = urllib.urlencode(args, True)
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
return FileBodyProducer(StringIO(query_bytes))
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
response = yield self.agent.request(
|
||||
"POST",
|
||||
path.encode("ascii"),
|
||||
body_callback=body_callback,
|
||||
headers_dict={
|
||||
uri.encode("ascii"),
|
||||
headers=Headers({
|
||||
"Content-Type": ["application/x-www-form-urlencoded"]
|
||||
}
|
||||
}),
|
||||
bodyProducer=FileBodyProducer(StringIO(query_bytes))
|
||||
)
|
||||
|
||||
body = yield readBody(response)
|
||||
|
@ -324,13 +61,11 @@ class IdentityServerHttpClient(BaseHttpClient):
|
|||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
|
||||
""" Get's some json from the given host homeserver and path
|
||||
def get_json(self, uri, args={}):
|
||||
""" Get's some json from the given host and path
|
||||
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request
|
||||
to.
|
||||
path (str): The HTTP path.
|
||||
uri (str): The URI to request, not including query parameters
|
||||
args (dict): A dictionary used to create query strings, defaults to
|
||||
None.
|
||||
**Note**: The value of each key is assumed to be an iterable
|
||||
|
@ -342,18 +77,15 @@ class IdentityServerHttpClient(BaseHttpClient):
|
|||
The result of the deferred is a tuple of `(code, response)`,
|
||||
where `response` is a dict representing the decoded JSON body.
|
||||
"""
|
||||
logger.debug("get_json args: %s", args)
|
||||
|
||||
query_bytes = urllib.urlencode(args, True)
|
||||
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
|
||||
yield
|
||||
if len(args):
|
||||
query_bytes = urllib.urlencode(args, True)
|
||||
uri = "%s?%s" % (uri, query_bytes)
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
response = yield self.agent.request(
|
||||
"GET",
|
||||
path.encode("ascii"),
|
||||
query_bytes=query_bytes,
|
||||
retry_on_dns_fail=retry_on_dns_fail,
|
||||
body_callback=None
|
||||
uri.encode("ascii"),
|
||||
)
|
||||
|
||||
body = yield readBody(response)
|
||||
|
@ -361,38 +93,31 @@ class IdentityServerHttpClient(BaseHttpClient):
|
|||
defer.returnValue(json.loads(body))
|
||||
|
||||
|
||||
class CaptchaServerHttpClient(MatrixHttpClient):
|
||||
"""Separate HTTP client for talking to google's captcha servers"""
|
||||
|
||||
def _getEndpoint(self, reactor, destination):
|
||||
return matrix_endpoint(reactor, destination, timeout=10)
|
||||
class CaptchaServerHttpClient(SimpleHttpClient):
|
||||
"""
|
||||
Separate HTTP client for talking to google's captcha servers
|
||||
Only slightly special because accepts partial download responses
|
||||
"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_urlencoded_get_raw(self, destination, path, accept_partial=False,
|
||||
args={}):
|
||||
def post_urlencoded_get_raw(self, url, args={}):
|
||||
query_bytes = urllib.urlencode(args, True)
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
return FileBodyProducer(StringIO(query_bytes))
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
response = yield self.agent.request(
|
||||
"POST",
|
||||
path.encode("ascii"),
|
||||
body_callback=body_callback,
|
||||
headers_dict={
|
||||
url.encode("ascii"),
|
||||
bodyProducer=FileBodyProducer(StringIO(query_bytes)),
|
||||
headers=Headers({
|
||||
"Content-Type": ["application/x-www-form-urlencoded"]
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
try:
|
||||
body = yield readBody(response)
|
||||
defer.returnValue(body)
|
||||
except PartialDownloadError as e:
|
||||
if accept_partial:
|
||||
defer.returnValue(e.response)
|
||||
else:
|
||||
raise e
|
||||
# twisted dislikes google's response, no content length.
|
||||
defer.returnValue(e.response)
|
||||
|
||||
|
||||
def _print_ex(e):
|
||||
|
@ -401,24 +126,3 @@ def _print_ex(e):
|
|||
_print_ex(ex)
|
||||
else:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
class _JsonProducer(object):
|
||||
""" Used by the twisted http client to create the HTTP body from json
|
||||
"""
|
||||
def __init__(self, jsn):
|
||||
self.reset(jsn)
|
||||
|
||||
def reset(self, jsn):
|
||||
self.body = encode_canonical_json(jsn)
|
||||
self.length = len(self.body)
|
||||
|
||||
def startProducing(self, consumer):
|
||||
consumer.write(self.body)
|
||||
return defer.succeed(None)
|
||||
|
||||
def pauseProducing(self):
|
||||
pass
|
||||
|
||||
def stopProducing(self):
|
||||
pass
|
||||
|
|
|
@ -131,11 +131,13 @@ class ContentRepoResource(resource.Resource):
|
|||
request.setHeader('Content-Type', content_type)
|
||||
|
||||
# cache for at least a day.
|
||||
# XXX: we might want to turn this off for data we don't want to recommend
|
||||
# caching as it's sensitive or private - or at least select private.
|
||||
# don't bother setting Expires as all our matrix clients are smart enough to
|
||||
# be happy with Cache-Control (right?)
|
||||
request.setHeader('Cache-Control', 'public,max-age=86400,s-maxage=86400')
|
||||
# XXX: we might want to turn this off for data we don't want to
|
||||
# recommend caching as it's sensitive or private - or at least
|
||||
# select private. don't bother setting Expires as all our matrix
|
||||
# clients are smart enough to be happy with Cache-Control (right?)
|
||||
request.setHeader(
|
||||
"Cache-Control", "public,max-age=86400,s-maxage=86400"
|
||||
)
|
||||
|
||||
d = FileSender().beginFileTransfer(f, request)
|
||||
|
||||
|
@ -179,7 +181,7 @@ class ContentRepoResource(resource.Resource):
|
|||
|
||||
fname = yield self.map_request_to_name(request)
|
||||
|
||||
# TODO I have a suspcious feeling this is just going to block
|
||||
# TODO I have a suspicious feeling this is just going to block
|
||||
with open(fname, "wb") as f:
|
||||
f.write(request.content.read())
|
||||
|
||||
|
@ -188,7 +190,7 @@ class ContentRepoResource(resource.Resource):
|
|||
# FIXME: we can't assume what the repo's public mounted path is
|
||||
# ...plus self-signed SSL won't work to remote clients anyway
|
||||
# ...and we can't assume that it's SSL anyway, as we might want to
|
||||
# server it via the non-SSL listener...
|
||||
# serve it via the non-SSL listener...
|
||||
url = "%s/_matrix/content/%s" % (
|
||||
self.external_addr, file_name
|
||||
)
|
||||
|
|
|
@ -27,8 +27,8 @@ import random
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def matrix_endpoint(reactor, destination, ssl_context_factory=None,
|
||||
timeout=None):
|
||||
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
|
||||
timeout=None):
|
||||
"""Construct an endpoint for the given matrix destination.
|
||||
|
||||
Args:
|
||||
|
|
308
synapse/http/matrixfederationclient.py
Normal file
308
synapse/http/matrixfederationclient.py
Normal file
|
@ -0,0 +1,308 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.web.client import readBody, _AgentBase, _URI
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
from synapse.api.errors import CodeMessageException, SynapseError
|
||||
|
||||
from syutil.crypto.jsonsign import sign_json
|
||||
|
||||
import json
|
||||
import logging
|
||||
import urllib
|
||||
import urlparse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MatrixFederationHttpAgent(_AgentBase):
|
||||
|
||||
def __init__(self, reactor, pool=None):
|
||||
_AgentBase.__init__(self, reactor, pool)
|
||||
|
||||
def request(self, destination, endpoint, method, path, params, query,
|
||||
headers, body_producer):
|
||||
|
||||
host = b""
|
||||
port = 0
|
||||
fragment = b""
|
||||
|
||||
parsed_URI = _URI(b"http", destination, host, port, path, params,
|
||||
query, fragment)
|
||||
|
||||
# Set the connection pool key to be the destination.
|
||||
key = destination
|
||||
|
||||
return self._requestWithEndpoint(key, endpoint, method, parsed_URI,
|
||||
headers, body_producer,
|
||||
parsed_URI.originForm)
|
||||
|
||||
|
||||
class MatrixFederationHttpClient(object):
|
||||
"""HTTP client used to talk to other homeservers over the federation
|
||||
protocol. Send client certificates and signs requests.
|
||||
|
||||
Attributes:
|
||||
agent (twisted.web.client.Agent): The twisted Agent used to send the
|
||||
requests.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
self.agent = MatrixFederationHttpAgent(reactor)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_request(self, destination, method, path_bytes,
|
||||
body_callback, headers_dict={}, param_bytes=b"",
|
||||
query_bytes=b"", retry_on_dns_fail=True):
|
||||
""" Creates and sends a request to the given url
|
||||
"""
|
||||
headers_dict[b"User-Agent"] = [b"Synapse"]
|
||||
headers_dict[b"Host"] = [destination]
|
||||
|
||||
url_bytes = urlparse.urlunparse(
|
||||
("", "", path_bytes, param_bytes, query_bytes, "",)
|
||||
)
|
||||
|
||||
logger.debug("Sending request to %s: %s %s",
|
||||
destination, method, url_bytes)
|
||||
|
||||
logger.debug(
|
||||
"Types: %s",
|
||||
[
|
||||
type(destination), type(method), type(path_bytes),
|
||||
type(param_bytes),
|
||||
type(query_bytes)
|
||||
]
|
||||
)
|
||||
|
||||
retries_left = 5
|
||||
|
||||
endpoint = self._getEndpoint(reactor, destination)
|
||||
|
||||
while True:
|
||||
producer = None
|
||||
if body_callback:
|
||||
producer = body_callback(method, url_bytes, headers_dict)
|
||||
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
response = yield self.agent.request(
|
||||
destination,
|
||||
endpoint,
|
||||
method,
|
||||
path_bytes,
|
||||
param_bytes,
|
||||
query_bytes,
|
||||
Headers(headers_dict),
|
||||
producer
|
||||
)
|
||||
|
||||
logger.debug("Got response to %s", method)
|
||||
break
|
||||
except Exception as e:
|
||||
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
|
||||
logger.warn("DNS Lookup failed to %s with %s", destination,
|
||||
e)
|
||||
raise SynapseError(400, "Domain specified not found.")
|
||||
|
||||
logger.exception("Got error in _create_request")
|
||||
_print_ex(e)
|
||||
|
||||
if retries_left:
|
||||
yield sleep(2 ** (5 - retries_left))
|
||||
retries_left -= 1
|
||||
else:
|
||||
raise
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
pass
|
||||
else:
|
||||
# :'(
|
||||
# Update transactions table?
|
||||
logger.error(
|
||||
"Got response %d %s", response.code, response.phrase
|
||||
)
|
||||
raise CodeMessageException(
|
||||
response.code, response.phrase
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
||||
def sign_request(self, destination, method, url_bytes, headers_dict,
|
||||
content=None):
|
||||
request = {
|
||||
"method": method,
|
||||
"uri": url_bytes,
|
||||
"origin": self.server_name,
|
||||
"destination": destination,
|
||||
}
|
||||
|
||||
if content is not None:
|
||||
request["content"] = content
|
||||
|
||||
request = sign_json(request, self.server_name, self.signing_key)
|
||||
|
||||
auth_headers = []
|
||||
|
||||
for key, sig in request["signatures"][self.server_name].items():
|
||||
auth_headers.append(bytes(
|
||||
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
||||
self.server_name, key, sig,
|
||||
)
|
||||
))
|
||||
|
||||
headers_dict[b"Authorization"] = auth_headers
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def put_json(self, destination, path, data={}, json_data_callback=None):
|
||||
""" Sends the specifed json data using PUT
|
||||
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request
|
||||
to.
|
||||
path (str): The HTTP path.
|
||||
data (dict): A dict containing the data that will be used as
|
||||
the request body. This will be encoded as JSON.
|
||||
json_data_callback (callable): A callable returning the dict to
|
||||
use as the request body.
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body. On a 4xx or 5xx error response a
|
||||
CodeMessageException is raised.
|
||||
"""
|
||||
|
||||
if not json_data_callback:
|
||||
def json_data_callback():
|
||||
return data
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
json_data = json_data_callback()
|
||||
self.sign_request(
|
||||
destination, method, url_bytes, headers_dict, json_data
|
||||
)
|
||||
producer = _JsonProducer(json_data)
|
||||
return producer
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
"PUT",
|
||||
path.encode("ascii"),
|
||||
body_callback=body_callback,
|
||||
headers_dict={"Content-Type": ["application/json"]},
|
||||
)
|
||||
|
||||
logger.debug("Getting resp body")
|
||||
body = yield readBody(response)
|
||||
logger.debug("Got resp body")
|
||||
|
||||
defer.returnValue((response.code, body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
|
||||
""" Get's some json from the given host homeserver and path
|
||||
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request
|
||||
to.
|
||||
path (str): The HTTP path.
|
||||
args (dict): A dictionary used to create query strings, defaults to
|
||||
None.
|
||||
**Note**: The value of each key is assumed to be an iterable
|
||||
and *not* a string.
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* HTTP response.
|
||||
|
||||
The result of the deferred is a tuple of `(code, response)`,
|
||||
where `response` is a dict representing the decoded JSON body.
|
||||
"""
|
||||
logger.debug("get_json args: %s", args)
|
||||
|
||||
encoded_args = {}
|
||||
for k, vs in args.items():
|
||||
if isinstance(vs, basestring):
|
||||
vs = [vs]
|
||||
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
||||
|
||||
query_bytes = urllib.urlencode(encoded_args, True)
|
||||
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
self.sign_request(destination, method, url_bytes, headers_dict)
|
||||
return None
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
"GET",
|
||||
path.encode("ascii"),
|
||||
query_bytes=query_bytes,
|
||||
body_callback=body_callback,
|
||||
retry_on_dns_fail=retry_on_dns_fail
|
||||
)
|
||||
|
||||
body = yield readBody(response)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
def _getEndpoint(self, reactor, destination):
|
||||
return matrix_federation_endpoint(
|
||||
reactor, destination, timeout=10,
|
||||
ssl_context_factory=self.hs.tls_context_factory
|
||||
)
|
||||
|
||||
|
||||
def _print_ex(e):
|
||||
if hasattr(e, "reasons") and e.reasons:
|
||||
for ex in e.reasons:
|
||||
_print_ex(ex)
|
||||
else:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
class _JsonProducer(object):
|
||||
""" Used by the twisted http client to create the HTTP body from json
|
||||
"""
|
||||
def __init__(self, jsn):
|
||||
self.reset(jsn)
|
||||
|
||||
def reset(self, jsn):
|
||||
self.body = encode_canonical_json(jsn)
|
||||
self.length = len(self.body)
|
||||
|
||||
def startProducing(self, consumer):
|
||||
consumer.write(self.body)
|
||||
return defer.succeed(None)
|
||||
|
||||
def pauseProducing(self):
|
||||
pass
|
||||
|
||||
def stopProducing(self):
|
||||
pass
|
|
@ -138,8 +138,7 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
)
|
||||
except CodeMessageException as e:
|
||||
if isinstance(e, SynapseError):
|
||||
logger.error("%s SynapseError: %s - %s", request, e.code,
|
||||
e.msg)
|
||||
logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)
|
||||
else:
|
||||
logger.exception(e)
|
||||
self._send_response(
|
||||
|
|
|
@ -17,6 +17,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
import logging
|
||||
|
||||
|
@ -96,6 +97,7 @@ class Notifier(object):
|
|||
listening to the room, and any listeners for the users in the
|
||||
`extra_users` param.
|
||||
"""
|
||||
yield run_on_reactor()
|
||||
room_id = event.room_id
|
||||
|
||||
room_source = self.event_sources.sources["room"]
|
||||
|
@ -143,6 +145,7 @@ class Notifier(object):
|
|||
|
||||
Will wake up all listeners for the given users and rooms.
|
||||
"""
|
||||
yield run_on_reactor()
|
||||
presence_source = self.event_sources.sources["presence"]
|
||||
|
||||
listeners = set()
|
||||
|
@ -211,6 +214,7 @@ class Notifier(object):
|
|||
timeout,
|
||||
deferred,
|
||||
)
|
||||
|
||||
def _timeout_listener():
|
||||
# TODO (erikj): We should probably set to_token to the current
|
||||
# max rather than reusing from_token.
|
||||
|
|
|
@ -26,7 +26,6 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class EventStreamRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/events$")
|
||||
|
||||
|
|
|
@ -117,8 +117,6 @@ class PresenceListRestServlet(RestServlet):
|
|||
logger.exception("JSON parse error")
|
||||
raise SynapseError(400, "Unable to parse content")
|
||||
|
||||
deferreds = []
|
||||
|
||||
if "invite" in content:
|
||||
for u in content["invite"]:
|
||||
if not isinstance(u, basestring):
|
||||
|
@ -126,8 +124,9 @@ class PresenceListRestServlet(RestServlet):
|
|||
if len(u) == 0:
|
||||
continue
|
||||
invited_user = self.hs.parse_userid(u)
|
||||
deferreds.append(self.handlers.presence_handler.send_invite(
|
||||
observer_user=user, observed_user=invited_user))
|
||||
yield self.handlers.presence_handler.send_invite(
|
||||
observer_user=user, observed_user=invited_user
|
||||
)
|
||||
|
||||
if "drop" in content:
|
||||
for u in content["drop"]:
|
||||
|
@ -136,10 +135,9 @@ class PresenceListRestServlet(RestServlet):
|
|||
if len(u) == 0:
|
||||
continue
|
||||
dropped_user = self.hs.parse_userid(u)
|
||||
deferreds.append(self.handlers.presence_handler.drop(
|
||||
observer_user=user, observed_user=dropped_user))
|
||||
|
||||
yield defer.DeferredList(deferreds)
|
||||
yield self.handlers.presence_handler.drop(
|
||||
observer_user=user, observed_user=dropped_user
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
|
|
@ -222,6 +222,7 @@ class RegisterRestServlet(RestServlet):
|
|||
|
||||
threepidCreds = register_json['threepidCreds']
|
||||
handler = self.handlers.registration_handler
|
||||
logger.debug("Registering email. threepidcreds: %s" % (threepidCreds))
|
||||
yield handler.register_email(threepidCreds)
|
||||
session["threepidCreds"] = threepidCreds # store creds for next stage
|
||||
session[LoginType.EMAIL_IDENTITY] = True # mark email as done
|
||||
|
@ -232,6 +233,7 @@ class RegisterRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _do_password(self, request, register_json, session):
|
||||
yield
|
||||
if (self.hs.config.enable_registration_captcha and
|
||||
not session[LoginType.RECAPTCHA]):
|
||||
# captcha should've been done by this stage!
|
||||
|
@ -259,6 +261,9 @@ class RegisterRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
if session[LoginType.EMAIL_IDENTITY]:
|
||||
logger.debug("Binding emails %s to %s" % (
|
||||
session["threepidCreds"], user_id)
|
||||
)
|
||||
yield handler.bind_emails(user_id, session["threepidCreds"])
|
||||
|
||||
result = {
|
||||
|
|
|
@ -148,7 +148,7 @@ class RoomStateEventRestServlet(RestServlet):
|
|||
content = _parse_json(request)
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=urllib.unquote(event_type),
|
||||
etype=event_type, # already urldecoded
|
||||
content=content,
|
||||
room_id=urllib.unquote(room_id),
|
||||
user_id=user.to_string(),
|
||||
|
|
|
@ -82,7 +82,7 @@ class StateHandler(object):
|
|||
if hasattr(event, "outlier") and event.outlier:
|
||||
event.state_group = None
|
||||
event.old_state_events = None
|
||||
event.state_events = {}
|
||||
event.state_events = None
|
||||
defer.returnValue(False)
|
||||
return
|
||||
|
||||
|
|
|
@ -67,7 +67,7 @@ SCHEMAS = [
|
|||
|
||||
# Remember to update this number every time an incompatible change is made to
|
||||
# database schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 7
|
||||
SCHEMA_VERSION = 8
|
||||
|
||||
|
||||
class _RollbackButIsFineException(Exception):
|
||||
|
@ -93,7 +93,8 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def persist_event(self, event, backfilled=False, is_new_state=True):
|
||||
def persist_event(self, event, backfilled=False, is_new_state=True,
|
||||
current_state=None):
|
||||
stream_ordering = None
|
||||
if backfilled:
|
||||
if not self.min_token_deferred.called:
|
||||
|
@ -109,6 +110,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
backfilled=backfilled,
|
||||
stream_ordering=stream_ordering,
|
||||
is_new_state=is_new_state,
|
||||
current_state=current_state,
|
||||
)
|
||||
except _RollbackButIsFineException:
|
||||
pass
|
||||
|
@ -137,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
|
||||
@log_function
|
||||
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
|
||||
is_new_state=True):
|
||||
is_new_state=True, current_state=None):
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
self._store_room_member_txn(txn, event)
|
||||
elif event.type == FeedbackEvent.TYPE:
|
||||
|
@ -206,8 +208,24 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
|
||||
self._store_state_groups_txn(txn, event)
|
||||
|
||||
if current_state:
|
||||
txn.execute("DELETE FROM current_state_events")
|
||||
|
||||
for s in current_state:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
{
|
||||
"event_id": s.event_id,
|
||||
"room_id": s.room_id,
|
||||
"type": s.type,
|
||||
"state_key": s.state_key,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
is_state = hasattr(event, "state_key") and event.state_key is not None
|
||||
if is_new_state and is_state:
|
||||
if is_state:
|
||||
vals = {
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
|
@ -225,17 +243,18 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
or_replace=True,
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
if is_new_state:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
for e_id, h in event.prev_state:
|
||||
self._simple_insert_txn(
|
||||
|
@ -312,7 +331,12 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||
txn, event.event_id, ref_alg, ref_hash_bytes
|
||||
)
|
||||
|
||||
self._update_min_depth_for_room_txn(txn, event.room_id, event.depth)
|
||||
if not outlier:
|
||||
self._update_min_depth_for_room_txn(
|
||||
txn,
|
||||
event.room_id,
|
||||
event.depth
|
||||
)
|
||||
|
||||
def _store_redaction(self, txn, event):
|
||||
txn.execute(
|
||||
|
@ -508,7 +532,7 @@ def prepare_database(db_conn):
|
|||
"new for the server to understand"
|
||||
)
|
||||
elif user_version < SCHEMA_VERSION:
|
||||
logging.info(
|
||||
logger.info(
|
||||
"Upgrading database from version %d",
|
||||
user_version
|
||||
)
|
||||
|
|
|
@ -57,7 +57,7 @@ class LoggingTransaction(object):
|
|||
if args and args[0]:
|
||||
values = args[0]
|
||||
sql_logger.debug(
|
||||
"[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)),
|
||||
"[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)),
|
||||
self.name,
|
||||
*values
|
||||
)
|
||||
|
@ -91,6 +91,7 @@ class SQLBaseStore(object):
|
|||
def runInteraction(self, desc, func, *args, **kwargs):
|
||||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
||||
current_context = LoggingContext.current_context()
|
||||
|
||||
def inner_func(txn, *args, **kwargs):
|
||||
with LoggingContext("runInteraction") as context:
|
||||
current_context.copy_to(context)
|
||||
|
@ -115,7 +116,6 @@ class SQLBaseStore(object):
|
|||
"[TXN END] {%s} %f",
|
||||
name, end - start
|
||||
)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
result = yield self._db_pool.runInteraction(
|
||||
inner_func, *args, **kwargs
|
||||
|
@ -246,7 +246,10 @@ class SQLBaseStore(object):
|
|||
raise StoreError(404, "No row found")
|
||||
|
||||
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
|
||||
sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
|
||||
sql = (
|
||||
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s "
|
||||
"ORDER BY rowid asc"
|
||||
) % {
|
||||
"retcol": retcol,
|
||||
"table": table,
|
||||
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
|
||||
|
@ -299,7 +302,7 @@ class SQLBaseStore(object):
|
|||
keyvalues : dict of column names and values to select the rows with
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
|
@ -334,7 +337,7 @@ class SQLBaseStore(object):
|
|||
retcols=None, allow_none=False):
|
||||
""" Combined SELECT then UPDATE."""
|
||||
if retcols:
|
||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
|
@ -461,7 +464,7 @@ class SQLBaseStore(object):
|
|||
def _get_events_txn(self, txn, event_ids):
|
||||
# FIXME (erikj): This should be batched?
|
||||
|
||||
sql = "SELECT * FROM events WHERE event_id = ?"
|
||||
sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
|
||||
|
||||
event_rows = []
|
||||
for e_id in event_ids:
|
||||
|
@ -478,7 +481,9 @@ class SQLBaseStore(object):
|
|||
def _parse_events_txn(self, txn, rows):
|
||||
events = [self._parse_event_from_row(r) for r in rows]
|
||||
|
||||
select_event_sql = "SELECT * FROM events WHERE event_id = ?"
|
||||
select_event_sql = (
|
||||
"SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
|
||||
)
|
||||
|
||||
for i, ev in enumerate(events):
|
||||
signatures = self._get_event_signatures_txn(
|
||||
|
|
|
@ -75,7 +75,9 @@ class RegistrationStore(SQLBaseStore):
|
|||
"VALUES (?,?,?)",
|
||||
[user_id, password_hash, now])
|
||||
except IntegrityError:
|
||||
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
|
||||
raise StoreError(
|
||||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||
)
|
||||
|
||||
# it's possible for this to get a conflict, but only for a single user
|
||||
# since tokens are namespaced based on their user ID
|
||||
|
@ -83,8 +85,8 @@ class RegistrationStore(SQLBaseStore):
|
|||
"VALUES (?,?)", [txn.lastrowid, token])
|
||||
|
||||
def get_user_by_id(self, user_id):
|
||||
query = ("SELECT users.name, users.password_hash FROM users "
|
||||
"WHERE users.name = ?")
|
||||
query = ("SELECT users.name, users.password_hash FROM users"
|
||||
" WHERE users.name = ?")
|
||||
return self._execute(
|
||||
self.cursor_to_dict,
|
||||
query, user_id
|
||||
|
@ -120,10 +122,10 @@ class RegistrationStore(SQLBaseStore):
|
|||
|
||||
def _query_for_auth(self, txn, token):
|
||||
sql = (
|
||||
"SELECT users.name, users.admin, access_tokens.device_id "
|
||||
"FROM users "
|
||||
"INNER JOIN access_tokens on users.id = access_tokens.user_id "
|
||||
"WHERE token = ?"
|
||||
"SELECT users.name, users.admin, access_tokens.device_id"
|
||||
" FROM users"
|
||||
" INNER JOIN access_tokens on users.id = access_tokens.user_id"
|
||||
" WHERE token = ?"
|
||||
)
|
||||
|
||||
cursor = txn.execute(sql, (token,))
|
||||
|
|
|
@ -27,7 +27,9 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OpsLevel = collections.namedtuple("OpsLevel", ("ban_level", "kick_level", "redact_level"))
|
||||
OpsLevel = collections.namedtuple("OpsLevel", (
|
||||
"ban_level", "kick_level", "redact_level")
|
||||
)
|
||||
|
||||
|
||||
class RoomStore(SQLBaseStore):
|
||||
|
|
|
@ -177,8 +177,8 @@ class RoomMemberStore(SQLBaseStore):
|
|||
return self._get_members_query(clause, vals)
|
||||
|
||||
def _get_members_query(self, where_clause, where_values):
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_members_query_txn,
|
||||
return self.runInteraction(
|
||||
"get_members_query", self._get_members_query_txn,
|
||||
where_clause, where_values
|
||||
)
|
||||
|
||||
|
|
34
synapse/storage/schema/delta/v8.sql
Normal file
34
synapse/storage/schema/delta/v8.sql
Normal file
|
@ -0,0 +1,34 @@
|
|||
/* Copyright 2014 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS event_signatures_2 (
|
||||
event_id TEXT,
|
||||
signature_name TEXT,
|
||||
key_id TEXT,
|
||||
signature BLOB,
|
||||
CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id)
|
||||
);
|
||||
|
||||
INSERT INTO event_signatures_2 (event_id, signature_name, key_id, signature)
|
||||
SELECT event_id, signature_name, key_id, signature FROM event_signatures;
|
||||
|
||||
DROP TABLE event_signatures;
|
||||
ALTER TABLE event_signatures_2 RENAME TO event_signatures;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures (
|
||||
event_id
|
||||
);
|
||||
|
||||
PRAGMA user_version = 8;
|
|
@ -42,7 +42,7 @@ CREATE TABLE IF NOT EXISTS event_signatures (
|
|||
signature_name TEXT,
|
||||
key_id TEXT,
|
||||
signature BLOB,
|
||||
CONSTRAINT uniqueness UNIQUE (event_id, key_id)
|
||||
CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures (
|
||||
|
|
|
@ -36,7 +36,7 @@ class SignatureStore(SQLBaseStore):
|
|||
return dict(txn.fetchall())
|
||||
|
||||
def _store_event_content_hash_txn(self, txn, event_id, algorithm,
|
||||
hash_bytes):
|
||||
hash_bytes):
|
||||
"""Store a hash for a Event
|
||||
Args:
|
||||
txn (cursor):
|
||||
|
@ -84,7 +84,7 @@ class SignatureStore(SQLBaseStore):
|
|||
return dict(txn.fetchall())
|
||||
|
||||
def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
|
||||
hash_bytes):
|
||||
hash_bytes):
|
||||
"""Store a hash for a PDU
|
||||
Args:
|
||||
txn (cursor):
|
||||
|
@ -127,7 +127,7 @@ class SignatureStore(SQLBaseStore):
|
|||
return res
|
||||
|
||||
def _store_event_signature_txn(self, txn, event_id, signature_name, key_id,
|
||||
signature_bytes):
|
||||
signature_bytes):
|
||||
"""Store a signature from the origin server for a PDU.
|
||||
Args:
|
||||
txn (cursor):
|
||||
|
@ -169,7 +169,7 @@ class SignatureStore(SQLBaseStore):
|
|||
return results
|
||||
|
||||
def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
|
||||
algorithm, hash_bytes):
|
||||
algorithm, hash_bytes):
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"event_edge_hashes",
|
||||
|
|
|
@ -87,7 +87,7 @@ class StateStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
def _store_state_groups_txn(self, txn, event):
|
||||
if not event.state_events:
|
||||
if event.state_events is None:
|
||||
return
|
||||
|
||||
state_group = event.state_group
|
||||
|
|
|
@ -213,8 +213,8 @@ class StreamStore(SQLBaseStore):
|
|||
# Tokens really represent positions between elements, but we use
|
||||
# the convention of pointing to the event before the gap. Hence
|
||||
# we have a bit of asymmetry when it comes to equalities.
|
||||
from_comp = '<=' if direction =='b' else '>'
|
||||
to_comp = '>' if direction =='b' else '<='
|
||||
from_comp = '<=' if direction == 'b' else '>'
|
||||
to_comp = '>' if direction == 'b' else '<='
|
||||
order = "DESC" if direction == 'b' else "ASC"
|
||||
|
||||
args = [room_id]
|
||||
|
@ -235,9 +235,10 @@ class StreamStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
sql = (
|
||||
"SELECT *, (%(redacted)s) AS redacted FROM events "
|
||||
"WHERE outlier = 0 AND room_id = ? AND %(bounds)s "
|
||||
"ORDER BY topological_ordering %(order)s, stream_ordering %(order)s %(limit)s "
|
||||
"SELECT *, (%(redacted)s) AS redacted FROM events"
|
||||
" WHERE outlier = 0 AND room_id = ? AND %(bounds)s"
|
||||
" ORDER BY topological_ordering %(order)s,"
|
||||
" stream_ordering %(order)s %(limit)s"
|
||||
) % {
|
||||
"redacted": del_sql,
|
||||
"bounds": bounds,
|
||||
|
|
|
@ -28,11 +28,11 @@ class SourcePaginationConfig(object):
|
|||
specific event source."""
|
||||
|
||||
def __init__(self, from_key=None, to_key=None, direction='f',
|
||||
limit=0):
|
||||
limit=None):
|
||||
self.from_key = from_key
|
||||
self.to_key = to_key
|
||||
self.direction = 'f' if direction == 'f' else 'b'
|
||||
self.limit = int(limit)
|
||||
self.limit = int(limit) if limit is not None else None
|
||||
|
||||
|
||||
class PaginationConfig(object):
|
||||
|
@ -40,11 +40,11 @@ class PaginationConfig(object):
|
|||
"""A configuration object which stores pagination parameters."""
|
||||
|
||||
def __init__(self, from_token=None, to_token=None, direction='f',
|
||||
limit=0):
|
||||
limit=None):
|
||||
self.from_token = from_token
|
||||
self.to_token = to_token
|
||||
self.direction = 'f' if direction == 'f' else 'b'
|
||||
self.limit = int(limit)
|
||||
self.limit = int(limit) if limit is not None else None
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, request, raise_invalid_params=True):
|
||||
|
@ -80,8 +80,8 @@ class PaginationConfig(object):
|
|||
except:
|
||||
raise SynapseError(400, "'to' paramater is invalid")
|
||||
|
||||
limit = get_param("limit", "0")
|
||||
if not limit.isdigit():
|
||||
limit = get_param("limit", None)
|
||||
if limit is not None and not limit.isdigit():
|
||||
raise SynapseError(400, "'limit' parameter must be an integer.")
|
||||
|
||||
try:
|
||||
|
|
|
@ -37,6 +37,7 @@ class Clock(object):
|
|||
|
||||
def call_later(self, delay, callback):
|
||||
current_context = LoggingContext.current_context()
|
||||
|
||||
def wrapped_callback():
|
||||
LoggingContext.thread_local.current_context = current_context
|
||||
callback()
|
||||
|
|
|
@ -18,6 +18,7 @@ from twisted.internet import defer, reactor
|
|||
|
||||
from .logcontext import PreserveLoggingContext
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def sleep(seconds):
|
||||
d = defer.Deferred()
|
||||
|
@ -25,6 +26,7 @@ def sleep(seconds):
|
|||
with PreserveLoggingContext():
|
||||
yield d
|
||||
|
||||
|
||||
def run_on_reactor():
|
||||
""" This will cause the rest of the function to be invoked upon the next
|
||||
iteration of the main loop
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
@ -91,6 +93,7 @@ class Signal(object):
|
|||
Each observer callable may return a Deferred."""
|
||||
self.observers.append(observer)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fire(self, *args, **kwargs):
|
||||
"""Invokes every callable in the observer list, passing in the args and
|
||||
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
||||
|
@ -98,22 +101,24 @@ class Signal(object):
|
|||
|
||||
Returns a Deferred that will complete when all the observers have
|
||||
completed."""
|
||||
deferreds = []
|
||||
for observer in self.observers:
|
||||
d = defer.maybeDeferred(observer, *args, **kwargs)
|
||||
with PreserveLoggingContext():
|
||||
deferreds = []
|
||||
for observer in self.observers:
|
||||
d = defer.maybeDeferred(observer, *args, **kwargs)
|
||||
|
||||
def eb(failure):
|
||||
logger.warning(
|
||||
"%s signal observer %s failed: %r",
|
||||
self.name, observer, failure,
|
||||
exc_info=(
|
||||
failure.type,
|
||||
failure.value,
|
||||
failure.getTracebackObject()))
|
||||
if not self.suppress_failures:
|
||||
raise failure
|
||||
deferreds.append(d.addErrback(eb))
|
||||
def eb(failure):
|
||||
logger.warning(
|
||||
"%s signal observer %s failed: %r",
|
||||
self.name, observer, failure,
|
||||
exc_info=(
|
||||
failure.type,
|
||||
failure.value,
|
||||
failure.getTracebackObject()))
|
||||
if not self.suppress_failures:
|
||||
raise failure
|
||||
deferreds.append(d.addErrback(eb))
|
||||
|
||||
return defer.DeferredList(
|
||||
deferreds, fireOnOneErrback=not self.suppress_failures
|
||||
)
|
||||
result = yield defer.DeferredList(
|
||||
deferreds, fireOnOneErrback=not self.suppress_failures
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import threading
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoggingContext(object):
|
||||
"""Additional context for log formatting. Contexts are scoped within a
|
||||
|
@ -53,11 +55,14 @@ class LoggingContext(object):
|
|||
None to avoid suppressing any exeptions that were thrown.
|
||||
"""
|
||||
if self.thread_local.current_context is not self:
|
||||
logging.error(
|
||||
"Current logging context %s is not the expected context %s",
|
||||
self.thread_local.current_context,
|
||||
self
|
||||
)
|
||||
if self.thread_local.current_context is self.sentinel:
|
||||
logger.debug("Expected logging context %s has been lost", self)
|
||||
else:
|
||||
logger.warn(
|
||||
"Current logging context %s is not expected context %s",
|
||||
self.thread_local.current_context,
|
||||
self
|
||||
)
|
||||
self.thread_local.current_context = self.parent_context
|
||||
self.parent_context = None
|
||||
|
||||
|
|
|
@ -83,20 +83,22 @@ class FederationTestCase(unittest.TestCase):
|
|||
event_id="$a:b",
|
||||
user_id="@a:b",
|
||||
origin="b",
|
||||
auth_events=[],
|
||||
hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
|
||||
)
|
||||
|
||||
self.datastore.persist_event.return_value = defer.succeed(None)
|
||||
self.datastore.get_room.return_value = defer.succeed(True)
|
||||
|
||||
self.state_handler.annotate_event_with_state.return_value = (
|
||||
defer.succeed(False)
|
||||
)
|
||||
def annotate(ev, old_state=None):
|
||||
ev.old_state_events = []
|
||||
return defer.succeed(False)
|
||||
self.state_handler.annotate_event_with_state.side_effect = annotate
|
||||
|
||||
yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
|
||||
|
||||
self.datastore.persist_event.assert_called_once_with(
|
||||
ANY, False, is_new_state=False
|
||||
ANY, is_new_state=False, backfilled=False, current_state=None
|
||||
)
|
||||
|
||||
self.state_handler.annotate_event_with_state.assert_called_once_with(
|
||||
|
@ -104,7 +106,7 @@ class FederationTestCase(unittest.TestCase):
|
|||
old_state=None,
|
||||
)
|
||||
|
||||
self.auth.check.assert_called_once_with(ANY, raises=True)
|
||||
self.auth.check.assert_called_once_with(ANY, auth_events={})
|
||||
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
ANY,
|
||||
|
|
|
@ -120,7 +120,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
|
||||
self.datastore.get_room_member.return_value = defer.succeed(None)
|
||||
|
||||
event.state_events = {
|
||||
event.old_state_events = {
|
||||
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
|
||||
user_id="@alice:green",
|
||||
room_id=room_id,
|
||||
|
@ -129,9 +129,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
user_id="@bob:red",
|
||||
room_id=room_id,
|
||||
),
|
||||
(RoomMemberEvent.TYPE, target_user_id): event,
|
||||
}
|
||||
|
||||
event.state_events = event.old_state_events
|
||||
event.state_events[(RoomMemberEvent.TYPE, target_user_id)] = event
|
||||
|
||||
# Actual invocation
|
||||
yield self.room_member_handler.change_membership(event)
|
||||
|
||||
|
@ -187,6 +189,16 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
|||
(RoomMemberEvent.TYPE, user_id): event,
|
||||
}
|
||||
|
||||
event.old_state_events = {
|
||||
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
|
||||
user_id="@alice:green",
|
||||
room_id=room_id,
|
||||
),
|
||||
}
|
||||
|
||||
event.state_events = event.old_state_events
|
||||
event.state_events[(RoomMemberEvent.TYPE, user_id)] = event
|
||||
|
||||
# Actual invocation
|
||||
yield self.room_member_handler.change_membership(event)
|
||||
|
||||
|
|
|
@ -84,7 +84,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
|
||||
self.assertEquals("Value", value)
|
||||
self.mock_txn.execute.assert_called_with(
|
||||
"SELECT retcol FROM tablename WHERE keycol = ?",
|
||||
"SELECT retcol FROM tablename WHERE keycol = ? "
|
||||
"ORDER BY rowid asc",
|
||||
["TheKey"]
|
||||
)
|
||||
|
||||
|
@ -101,7 +102,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
|
||||
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
|
||||
self.mock_txn.execute.assert_called_with(
|
||||
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?",
|
||||
"SELECT colA, colB, colC FROM tablename WHERE keycol = ? "
|
||||
"ORDER BY rowid asc",
|
||||
["TheKey"]
|
||||
)
|
||||
|
||||
|
@ -135,7 +137,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
|
||||
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
|
||||
self.mock_txn.execute.assert_called_with(
|
||||
"SELECT colA FROM tablename WHERE keycol = ?",
|
||||
"SELECT colA FROM tablename WHERE keycol = ? "
|
||||
"ORDER BY rowid asc",
|
||||
["A set"]
|
||||
)
|
||||
|
||||
|
@ -184,7 +187,8 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
|||
|
||||
self.assertEquals({"columname": "Old Value"}, ret)
|
||||
self.mock_txn.execute.assert_has_calls([
|
||||
call('SELECT columname FROM tablename WHERE keycol = ?',
|
||||
call('SELECT columname FROM tablename WHERE keycol = ? '
|
||||
'ORDER BY rowid asc',
|
||||
['TheKey']),
|
||||
call("UPDATE tablename SET columname = ? WHERE keycol = ?",
|
||||
["New Value", "TheKey"])
|
||||
|
|
Loading…
Reference in a new issue