forked from MirrorHub/synapse
Move validation logic for AS 3PE query response into ApplicationServiceApi class, to keep the handler logic neater
This commit is contained in:
parent
697872cf08
commit
65201631a4
2 changed files with 44 additions and 45 deletions
|
@ -25,6 +25,28 @@ import urllib
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_valid_3pe_result(r, field):
|
||||||
|
if not isinstance(r, dict):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for k in (field, "protocol"):
|
||||||
|
if k not in r:
|
||||||
|
return False
|
||||||
|
if not isinstance(r[k], str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if "fields" not in r:
|
||||||
|
return False
|
||||||
|
fields = r["fields"]
|
||||||
|
if not isinstance(fields, dict):
|
||||||
|
return False
|
||||||
|
for k in fields.keys():
|
||||||
|
if not isinstance(fields[k], str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceApi(SimpleHttpClient):
|
class ApplicationServiceApi(SimpleHttpClient):
|
||||||
"""This class manages HS -> AS communications, including querying and
|
"""This class manages HS -> AS communications, including querying and
|
||||||
pushing.
|
pushing.
|
||||||
|
@ -76,8 +98,10 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
def query_3pe(self, service, kind, protocol, fields):
|
def query_3pe(self, service, kind, protocol, fields):
|
||||||
if kind == ThirdPartyEntityKind.USER:
|
if kind == ThirdPartyEntityKind.USER:
|
||||||
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
|
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
|
||||||
|
required_field = "userid"
|
||||||
elif kind == ThirdPartyEntityKind.LOCATION:
|
elif kind == ThirdPartyEntityKind.LOCATION:
|
||||||
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
|
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
|
||||||
|
required_field = "alias"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
||||||
|
@ -85,7 +109,24 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = yield self.get_json(uri, fields)
|
response = yield self.get_json(uri, fields)
|
||||||
defer.returnValue(response)
|
if not isinstance(response, list):
|
||||||
|
logger.warning(
|
||||||
|
"query_3pe to %s returned an invalid response %r",
|
||||||
|
uri, response
|
||||||
|
)
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
|
ret = []
|
||||||
|
for r in response:
|
||||||
|
if _is_valid_3pe_result(r, field=required_field):
|
||||||
|
ret.append(r)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"query_3pe to %s returned an invalid result %r",
|
||||||
|
uri, r
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(ret)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.warning("query_3pe to %s threw exception %s", uri, ex)
|
logger.warning("query_3pe to %s threw exception %s", uri, ex)
|
||||||
defer.returnValue([])
|
defer.returnValue([])
|
||||||
|
|
|
@ -18,7 +18,6 @@ from twisted.internet import defer
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn
|
||||||
from synapse.types import ThirdPartyEntityKind
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -36,28 +35,6 @@ def log_failure(failure):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_valid_3pentity_result(r, field):
|
|
||||||
if not isinstance(r, dict):
|
|
||||||
return False
|
|
||||||
|
|
||||||
for k in (field, "protocol"):
|
|
||||||
if k not in r:
|
|
||||||
return False
|
|
||||||
if not isinstance(r[k], str):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if "fields" not in r:
|
|
||||||
return False
|
|
||||||
fields = r["fields"]
|
|
||||||
if not isinstance(fields, dict):
|
|
||||||
return False
|
|
||||||
for k in fields.keys():
|
|
||||||
if not isinstance(fields[k], str):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServicesHandler(object):
|
class ApplicationServicesHandler(object):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -178,29 +155,10 @@ class ApplicationServicesHandler(object):
|
||||||
for service in services
|
for service in services
|
||||||
], consumeErrors=True)
|
], consumeErrors=True)
|
||||||
|
|
||||||
required_field = (
|
|
||||||
"userid" if kind == ThirdPartyEntityKind.USER else
|
|
||||||
"alias" if kind == ThirdPartyEntityKind.LOCATION else
|
|
||||||
None
|
|
||||||
)
|
|
||||||
|
|
||||||
ret = []
|
ret = []
|
||||||
for (success, result) in results:
|
for (success, result) in results:
|
||||||
if not success:
|
if success:
|
||||||
logger.warn("Application service failed %r", result)
|
ret.extend(result)
|
||||||
continue
|
|
||||||
if not isinstance(result, list):
|
|
||||||
logger.warn(
|
|
||||||
"Application service returned an invalid response %r", result
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
for r in result:
|
|
||||||
if _is_valid_3pentity_result(r, field=required_field):
|
|
||||||
ret.append(r)
|
|
||||||
else:
|
|
||||||
logger.warn(
|
|
||||||
"Application service returned an invalid result %r", r
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue