Filter 3PU lookups by only ASes that declare knowledge of that protocol

This commit is contained in:
Paul "LeoNerd" Evans 2016-08-18 14:56:02 +01:00
parent d5bf7a4a99
commit 434bbf2cb5
3 changed files with 22 additions and 3 deletions

View file

@ -81,13 +81,17 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None, def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, id=None): sender=None, id=None, protocols=None):
self.token = token self.token = token
self.url = url self.url = url
self.hs_token = hs_token self.hs_token = hs_token
self.sender = sender self.sender = sender
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
if protocols:
self.protocols = set(protocols)
else:
self.protocols = set()
def _check_namespaces(self, namespaces): def _check_namespaces(self, namespaces):
# Sanity check that it is of the form: # Sanity check that it is of the form:
@ -219,6 +223,9 @@ class ApplicationService(object):
or user_id == self.sender or user_id == self.sender
) )
def is_interested_in_protocol(self, protocol):
return protocol in self.protocols
def is_exclusive_alias(self, alias): def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias) return self._is_exclusive(ApplicationService.NS_ALIASES, alias)

View file

@ -122,6 +122,15 @@ def _load_appservice(hostname, as_info, config_filename):
raise ValueError( raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj "Missing/bad type 'exclusive' key in %s", regex_obj
) )
# protocols check
protocols = as_info.get("protocols")
if protocols:
# Because strings are lists in python
if isinstance(protocols, str) or not isinstance(protocols, list):
raise KeyError("Optional 'protocols' must be a list if present.")
for p in protocols:
if not isinstance(p, str):
raise KeyError("Bad value for 'protocols' item")
return ApplicationService( return ApplicationService(
token=as_info["as_token"], token=as_info["as_token"],
url=as_info["url"], url=as_info["url"],
@ -129,4 +138,5 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"], hs_token=as_info["hs_token"],
sender=user_id, sender=user_id,
id=as_info["id"], id=as_info["id"],
protocols=protocols,
) )

View file

@ -191,9 +191,11 @@ class ApplicationServicesHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_services_for_3pn(self, protocol): def _get_services_for_3pn(self, protocol):
# TODO(paul): Filter by protocol
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
defer.returnValue(services) interested_list = [
s for s in services if s.is_interested_in_protocol(protocol)
]
defer.returnValue(interested_list)
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_unknown_user(self, user_id): def _is_unknown_user(self, user_id):