Merge branch 'develop' into markjh/bearer_token

This commit is contained in:
Mark Haines 2016-09-09 18:51:22 +01:00
commit 3ddec016ff
7 changed files with 89 additions and 9 deletions

View file

@ -32,6 +32,14 @@ HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable" APP_SERVICE_PREFIX = "/_matrix/app/unstable"
def _is_valid_3pe_metadata(info):
if "instances" not in info:
return False
if not isinstance(info["instances"], list):
return False
return True
def _is_valid_3pe_result(r, field): def _is_valid_3pe_result(r, field):
if not isinstance(r, dict): if not isinstance(r, dict):
return False return False
@ -162,11 +170,18 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.quote(protocol) urllib.quote(protocol)
) )
try: try:
defer.returnValue((yield self.get_json(uri, {}))) info = yield self.get_json(uri, {})
if not _is_valid_3pe_metadata(info):
logger.warning("query_3pe_protocol to %s did not return a"
" valid result", uri)
defer.returnValue(None)
defer.returnValue(info)
except Exception as ex: except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s", logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex) uri, ex)
defer.returnValue({}) defer.returnValue(None)
key = (service.id, protocol) key = (service.id, protocol)
return self.protocol_meta_cache.get(key) or ( return self.protocol_meta_cache.get(key) or (

View file

@ -229,7 +229,6 @@ class TransactionQueue(object):
"dropping transaction for now", "dropping transaction for now",
destination, destination,
) )
success = False
finally: finally:
# We want to be *very* sure we delete this after we stop processing # We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None) self.pending_transactions.pop(destination, None)

View file

@ -176,12 +176,41 @@ class ApplicationServicesHandler(object):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_3pe_protocols(self): def get_3pe_protocols(self, only_protocol=None):
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
protocols = {} protocols = {}
# Collect up all the individual protocol responses out of the ASes
for s in services: for s in services:
for p in s.protocols: for p in s.protocols:
protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p) if only_protocol is not None and p != only_protocol:
continue
if p not in protocols:
protocols[p] = []
info = yield self.appservice_api.get_3pe_protocol(s, p)
if info is not None:
protocols[p].append(info)
def _merge_instances(infos):
if not infos:
return {}
# Merge the 'instances' lists of multiple results, but just take
# the other fields from the first as they ought to be identical
# copy the result so as not to corrupt the cached one
combined = dict(infos[0])
combined["instances"] = list(combined["instances"])
for info in infos[1:]:
combined["instances"].extend(info["instances"])
return combined
for p in protocols.keys():
protocols[p] = _merge_instances(protocols[p])
defer.returnValue(protocols) defer.returnValue(protocols)

View file

@ -265,6 +265,12 @@ class PresenceHandler(object):
to_notify = {} # Changes we want to notify everyone about to_notify = {} # Changes we want to notify everyone about
to_federation_ping = {} # These need sending keep-alives to_federation_ping = {} # These need sending keep-alives
# Only bother handling the last presence change for each user
new_states_dict = {}
for new_state in new_states:
new_states_dict[new_state.user_id] = new_state
new_state = new_states_dict.values()
for new_state in new_states: for new_state in new_states:
user_id = new_state.user_id user_id = new_state.user_id

View file

@ -181,7 +181,7 @@ class ReplicationResource(Resource):
def replicate(self, request_streams, limit): def replicate(self, request_streams, limit):
writer = _Writer() writer = _Writer()
current_token = yield self.current_replication_token() current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token) logger.debug("Replicating up to %r", current_token)
yield self.account_data(writer, current_token, limit, request_streams) yield self.account_data(writer, current_token, limit, request_streams)
yield self.events(writer, current_token, limit, request_streams) yield self.events(writer, current_token, limit, request_streams)
@ -195,7 +195,7 @@ class ReplicationResource(Resource):
yield self.to_device(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total) logger.debug("Replicated %d rows", writer.total)
defer.returnValue(writer.finish()) defer.returnValue(writer.finish())
def streams(self, writer, current_token, request_streams): def streams(self, writer, current_token, request_streams):

View file

@ -22,7 +22,7 @@ from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event, format_event_for_client_v2
from synapse.http.servlet import parse_json_object_from_request, parse_string from synapse.http.servlet import parse_json_object_from_request, parse_string
import logging import logging
@ -120,6 +120,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(request, "format", default="content",
allowed_values=["content", "event"])
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
@ -134,6 +136,11 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
raise SynapseError( raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND 404, "Event not found.", errcode=Codes.NOT_FOUND
) )
if format == "event":
event = format_event_for_client_v2(data.get_dict())
defer.returnValue((200, event))
elif format == "content":
defer.returnValue((200, data.get_dict()["content"])) defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -42,6 +42,29 @@ class ThirdPartyProtocolsServlet(RestServlet):
defer.returnValue((200, protocols)) defer.returnValue((200, protocols))
class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
releases=())
def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
protocols = yield self.appservice_handler.get_3pe_protocols(
only_protocol=protocol,
)
if protocol in protocols:
defer.returnValue((200, protocols[protocol]))
else:
defer.returnValue((404, {"error": "Unknown protocol"}))
class ThirdPartyUserServlet(RestServlet): class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$", PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
releases=()) releases=())
@ -92,5 +115,6 @@ class ThirdPartyLocationServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
ThirdPartyProtocolsServlet(hs).register(http_server) ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyProtocolServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server) ThirdPartyLocationServlet(hs).register(http_server)