0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-19 05:21:55 +01:00

Copypasta the 3PU support code to also do 3PL

This commit is contained in:
Paul "LeoNerd" Evans 2016-08-18 16:09:50 +01:00
parent f3afd6ef1a
commit 06964c4a0a
3 changed files with 61 additions and 3 deletions

View file

@ -82,6 +82,17 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_3pu to %s threw exception %s", uri, ex)
defer.returnValue([])
@defer.inlineCallbacks
def query_3pl(self, service, protocol, fields):
uri = service.url + ("/3pl/%s" % urllib.quote(protocol))
response = None
try:
response = yield self.get_json(uri, fields)
defer.returnValue(response)
except Exception as ex:
logger.warning("query_3pl to %s threw exception %s", uri, ex)
defer.returnValue([])
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events)

View file

@ -34,11 +34,11 @@ def log_failure(failure):
)
)
def _is_valid_3pu_result(r):
def _is_valid_3pentity_result(r, field):
if not isinstance(r, dict):
return False
for k in ("userid", "protocol"):
for k in (field, "protocol"):
if k not in r:
return False
if not isinstance(r[k], str):
@ -185,7 +185,34 @@ class ApplicationServicesHandler(object):
if not isinstance(result, list):
continue
for r in result:
if _is_valid_3pu_result(r):
if _is_valid_3pentity_result(r, field="userid"):
ret.append(r)
else:
logger.warn("Application service returned an " +
"invalid result %r", r)
defer.returnValue(ret)
@defer.inlineCallbacks
def query_3pl(self, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
deferreds = []
for service in services:
deferreds.append(self.appservice_api.query_3pl(
service, protocol, fields
))
results = yield defer.DeferredList(deferreds, consumeErrors=True)
ret = []
for (success, result) in results:
if not success:
continue
if not isinstance(result, list):
continue
for r in result:
if _is_valid_3pentity_result(r, field="alias"):
ret.append(r)
else:
logger.warn("Application service returned an " +

View file

@ -43,5 +43,25 @@ class ThirdPartyUserServlet(RestServlet):
defer.returnValue((200, results))
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
releases=())
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pl(protocol, fields)
defer.returnValue((200, results))
def register_servlets(hs, http_server):
ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server)