mirror of https://mau.dev/maunium/synapse.git
Merge branch 'develop' into markjh/bearer_token
commit 3ddec016ff
7 changed files with 89 additions and 9 deletions
@@ -32,6 +32,14 @@ HOUR_IN_MS = 60 * 60 * 1000
 APP_SERVICE_PREFIX = "/_matrix/app/unstable"
 
 
+def _is_valid_3pe_metadata(info):
+    if "instances" not in info:
+        return False
+    if not isinstance(info["instances"], list):
+        return False
+    return True
+
+
 def _is_valid_3pe_result(r, field):
     if not isinstance(r, dict):
         return False
@@ -162,11 +170,18 @@ class ApplicationServiceApi(SimpleHttpClient):
                 urllib.quote(protocol)
             )
             try:
-                defer.returnValue((yield self.get_json(uri, {})))
+                info = yield self.get_json(uri, {})
+
+                if not _is_valid_3pe_metadata(info):
+                    logger.warning("query_3pe_protocol to %s did not return a"
+                                   " valid result", uri)
+                    defer.returnValue(None)
+
+                defer.returnValue(info)
             except Exception as ex:
                 logger.warning("query_3pe_protocol to %s threw exception %s",
                                uri, ex)
-                defer.returnValue({})
+                defer.returnValue(None)
 
         key = (service.id, protocol)
         return self.protocol_meta_cache.get(key) or (
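For reference, a minimal standalone sketch (not part of the commit) of the behaviour these two hunks introduce: a protocol metadata response is only accepted when it carries an "instances" list, and anything else now yields None rather than an empty dict. The sample payloads below are invented.

def _is_valid_3pe_metadata(info):
    # Same rule as the validator added above: a result must carry an
    # "instances" list; anything else is treated as malformed.
    if "instances" not in info:
        return False
    if not isinstance(info["instances"], list):
        return False
    return True


# Hypothetical appservice responses, purely for illustration.
good = {"user_fields": ["nick"], "instances": [{"desc": "Example network"}]}
bad = {"user_fields": ["nick"]}  # no "instances" key

assert _is_valid_3pe_metadata(good)
assert not _is_valid_3pe_metadata(bad)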
@@ -229,7 +229,6 @@ class TransactionQueue(object):
                 "dropping transaction for now",
                 destination,
             )
-            success = False
         finally:
             # We want to be *very* sure we delete this after we stop processing
             self.pending_transactions.pop(destination, None)
@@ -176,12 +176,41 @@ class ApplicationServicesHandler(object):
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def get_3pe_protocols(self):
+    def get_3pe_protocols(self, only_protocol=None):
         services = yield self.store.get_app_services()
         protocols = {}
+
+        # Collect up all the individual protocol responses out of the ASes
         for s in services:
             for p in s.protocols:
-                protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
+                if only_protocol is not None and p != only_protocol:
+                    continue
+
+                if p not in protocols:
+                    protocols[p] = []
+
+                info = yield self.appservice_api.get_3pe_protocol(s, p)
+
+                if info is not None:
+                    protocols[p].append(info)
+
+        def _merge_instances(infos):
+            if not infos:
+                return {}
+
+            # Merge the 'instances' lists of multiple results, but just take
+            # the other fields from the first as they ought to be identical
+            # copy the result so as not to corrupt the cached one
+            combined = dict(infos[0])
+            combined["instances"] = list(combined["instances"])
+
+            for info in infos[1:]:
+                combined["instances"].extend(info["instances"])
+
+            return combined
+
+        for p in protocols.keys():
+            protocols[p] = _merge_instances(protocols[p])
+
         defer.returnValue(protocols)
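A self-contained sketch of how the new per-protocol merging behaves (sample data invented, not from the commit): results from several appservices for the same protocol keep the first response's fields and concatenate every response's "instances" list.

def _merge_instances(infos):
    # Mirrors the helper added above: take the first result's fields,
    # concatenate every result's "instances" list.
    if not infos:
        return {}

    combined = dict(infos[0])
    combined["instances"] = list(combined["instances"])

    for info in infos[1:]:
        combined["instances"].extend(info["instances"])

    return combined


# Two hypothetical responses from different appservices for one protocol.
first = {"user_fields": ["nick"], "instances": [{"desc": "bridge A"}]}
second = {"user_fields": ["nick"], "instances": [{"desc": "bridge B"}]}

merged = _merge_instances([first, second])
assert merged["instances"] == [{"desc": "bridge A"}, {"desc": "bridge B"}]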
@@ -265,6 +265,12 @@ class PresenceHandler(object):
             to_notify = {}  # Changes we want to notify everyone about
             to_federation_ping = {}  # These need sending keep-alives
 
+            # Only bother handling the last presence change for each user
+            new_states_dict = {}
+            for new_state in new_states:
+                new_states_dict[new_state.user_id] = new_state
+            new_state = new_states_dict.values()
+
             for new_state in new_states:
                 user_id = new_state.user_id
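A small illustrative sketch of the deduplication step above, using an invented stand-in for Synapse's presence-state objects: writing each update into a dict keyed on user_id keeps only the last change per user.

from collections import namedtuple

# A stand-in for Synapse's presence state class, just for this sketch.
PresenceState = namedtuple("PresenceState", ["user_id", "state"])

new_states = [
    PresenceState("@alice:example.org", "online"),
    PresenceState("@bob:example.org", "unavailable"),
    PresenceState("@alice:example.org", "offline"),  # supersedes the first
]

# Only bother handling the last presence change for each user.
new_states_dict = {}
for new_state in new_states:
    new_states_dict[new_state.user_id] = new_state
deduped = list(new_states_dict.values())

assert len(deduped) == 2
assert new_states_dict["@alice:example.org"].state == "offline"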
@@ -181,7 +181,7 @@ class ReplicationResource(Resource):
     def replicate(self, request_streams, limit):
         writer = _Writer()
         current_token = yield self.current_replication_token()
-        logger.info("Replicating up to %r", current_token)
+        logger.debug("Replicating up to %r", current_token)
 
         yield self.account_data(writer, current_token, limit, request_streams)
         yield self.events(writer, current_token, limit, request_streams)
@@ -195,7 +195,7 @@ class ReplicationResource(Resource):
         yield self.to_device(writer, current_token, limit, request_streams)
         self.streams(writer, current_token, request_streams)
 
-        logger.info("Replicated %d rows", writer.total)
+        logger.debug("Replicated %d rows", writer.total)
         defer.returnValue(writer.finish())
 
     def streams(self, writer, current_token, request_streams):
@@ -22,7 +22,7 @@ from synapse.streams.config import PaginationConfig
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.filtering import Filter
 from synapse.types import UserID, RoomID, RoomAlias
-from synapse.events.utils import serialize_event
+from synapse.events.utils import serialize_event, format_event_for_client_v2
 from synapse.http.servlet import parse_json_object_from_request, parse_string
 
 import logging
@@ -120,6 +120,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request, room_id, event_type, state_key):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        format = parse_string(request, "format", default="content",
+                              allowed_values=["content", "event"])
 
         msg_handler = self.handlers.message_handler
         data = yield msg_handler.get_room_data(
@@ -134,6 +136,11 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
             raise SynapseError(
                 404, "Event not found.", errcode=Codes.NOT_FOUND
             )
 
-        defer.returnValue((200, data.get_dict()["content"]))
+        if format == "event":
+            event = format_event_for_client_v2(data.get_dict())
+            defer.returnValue((200, event))
+        elif format == "content":
+            defer.returnValue((200, data.get_dict()["content"]))
 
     @defer.inlineCallbacks
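To illustrate the new format query parameter on this GET handler, here is a hedged client-side sketch using the requests library; the homeserver URL, access token and room id are placeholders, and the /r0 path prefix is an assumption rather than something stated in the diff.

import requests

# Placeholder values for illustration only.
BASE = "https://matrix.example.org/_matrix/client/r0"
TOKEN = "example_access_token"
ROOM_ID = "!abc123:example.org"


def get_state(event_type, state_key="", fmt="content"):
    # GET /rooms/{roomId}/state/{eventType}[/{stateKey}]?format=content|event
    url = "%s/rooms/%s/state/%s" % (BASE, ROOM_ID, event_type)
    if state_key:
        url += "/%s" % state_key
    resp = requests.get(
        url,
        params={"format": fmt},
        headers={"Authorization": "Bearer %s" % TOKEN},
    )
    resp.raise_for_status()
    return resp.json()


content = get_state("m.room.name")              # just the event content
event = get_state("m.room.name", fmt="event")   # full event, client-v2 format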
@@ -42,6 +42,29 @@ class ThirdPartyProtocolsServlet(RestServlet):
         defer.returnValue((200, protocols))
 
 
+class ThirdPartyProtocolServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
+                                  releases=())
+
+    def __init__(self, hs):
+        super(ThirdPartyProtocolServlet, self).__init__()
+
+        self.auth = hs.get_auth()
+        self.appservice_handler = hs.get_application_service_handler()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, protocol):
+        yield self.auth.get_user_by_req(request)
+
+        protocols = yield self.appservice_handler.get_3pe_protocols(
+            only_protocol=protocol,
+        )
+        if protocol in protocols:
+            defer.returnValue((200, protocols[protocol]))
+        else:
+            defer.returnValue((404, {"error": "Unknown protocol"}))
+
+
 class ThirdPartyUserServlet(RestServlet):
     PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
                                   releases=())
@@ -92,5 +115,6 @@ class ThirdPartyLocationServlet(RestServlet):
 
 def register_servlets(hs, http_server):
     ThirdPartyProtocolsServlet(hs).register(http_server)
+    ThirdPartyProtocolServlet(hs).register(http_server)
     ThirdPartyUserServlet(hs).register(http_server)
     ThirdPartyLocationServlet(hs).register(http_server)
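A similar hedged sketch for the single-protocol lookup servlet added above; since it is registered with releases=(), it presumably lives only under the unstable client prefix. The server name, token and the "irc" protocol id are placeholders.

import requests

# Placeholder values for illustration only.
BASE = "https://matrix.example.org/_matrix/client/unstable"
TOKEN = "example_access_token"

resp = requests.get(
    "%s/thirdparty/protocol/%s" % (BASE, "irc"),
    headers={"Authorization": "Bearer %s" % TOKEN},
)

if resp.status_code == 200:
    print(resp.json())   # merged metadata for the protocol, including "instances"
elif resp.status_code == 404:
    print(resp.json())   # {"error": "Unknown protocol"}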