Merge pull request #116 from matrix-org/application-services-registration-script

Application services registration changes (PR #116)
This commit is contained in:
Kegsay 2015-04-01 13:51:15 +01:00
commit 80a620a83a
12 changed files with 159 additions and 455 deletions

View file

@ -32,15 +32,13 @@ from twisted.web.resource import Resource
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site from twisted.web.server import Site
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.appservice.v1 import AppServiceRestResource
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.http.server_key_resource import LocalKey from synapse.http.server_key_resource import LocalKey
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX, SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, STATIC_PREFIX
STATIC_PREFIX
) )
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
@ -78,9 +76,6 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_federation(self): def build_resource_for_federation(self):
return JsonResource(self) return JsonResource(self)
def build_resource_for_app_services(self):
return AppServiceRestResource(self)
def build_resource_for_web_client(self): def build_resource_for_web_client(self):
import syweb import syweb
syweb_path = os.path.dirname(syweb.__file__) syweb_path = os.path.dirname(syweb.__file__)
@ -141,7 +136,6 @@ class SynapseHomeServer(HomeServer):
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()), (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()), (SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
(MEDIA_PREFIX, self.get_resource_for_media_repository()), (MEDIA_PREFIX, self.get_resource_for_media_repository()),
(APP_SERVICE_PREFIX, self.get_resource_for_app_services()),
(STATIC_PREFIX, self.get_resource_for_static_content()), (STATIC_PREFIX, self.get_resource_for_static_content()),
] ]

View file

@ -95,7 +95,7 @@ class ApplicationService(object):
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...], # rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# } # }
if not namespaces: if not namespaces:
return None namespaces = {}
for ns in ApplicationService.NS_LIST: for ns in ApplicationService.NS_LIST:
if ns not in namespaces: if ns not in namespaces:

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2015 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -12,18 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import register
from synapse.http.server import JsonResource from ._base import Config
class AppServiceRestResource(JsonResource): class AppServiceConfig(Config):
"""A resource for version 1 of the matrix application service API."""
def __init__(self, hs): def __init__(self, args):
JsonResource.__init__(self, hs) super(AppServiceConfig, self).__init__(args)
self.register_servlets(self, hs) self.app_service_config_files = args.app_service_config_files
@staticmethod @classmethod
def register_servlets(appservice_resource, hs): def add_arguments(cls, parser):
register.register_servlets(hs, appservice_resource) super(AppServiceConfig, cls).add_arguments(parser)
group = parser.add_argument_group("appservice")
group.add_argument(
"--app-service-config-files", type=str, nargs='+',
help="A list of application service config files to use."
)

View file

@ -24,12 +24,13 @@ from .email import EmailConfig
from .voip import VoipConfig from .voip import VoipConfig
from .registration import RegistrationConfig from .registration import RegistrationConfig
from .metrics import MetricsConfig from .metrics import MetricsConfig
from .appservice import AppServiceConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
EmailConfig, VoipConfig, RegistrationConfig, EmailConfig, VoipConfig, RegistrationConfig,
MetricsConfig,): MetricsConfig, AppServiceConfig,):
pass pass

View file

@ -16,10 +16,8 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.types import UserID from synapse.types import UserID
import synapse.util.stringutils as stringutils
import logging import logging
@ -49,38 +47,6 @@ class ApplicationServicesHandler(object):
self.scheduler = appservice_scheduler self.scheduler = appservice_scheduler
self.started_scheduler = False self.started_scheduler = False
@defer.inlineCallbacks
def register(self, app_service):
logger.info("Register -> %s", app_service)
# check the token is recognised
try:
stored_service = yield self.store.get_app_service_by_token(
app_service.token
)
if not stored_service:
raise StoreError(404, "Application service not found")
app_service.id = stored_service.id
except StoreError:
raise SynapseError(
403, "Unrecognised application services token. "
"Consult the home server admin.",
errcode=Codes.FORBIDDEN
)
app_service.hs_token = self._generate_hs_token()
# create a sender for this application service which is used when
# creating rooms, etc..
account = yield self.hs.get_handlers().registration_handler.register()
app_service.sender = account[0]
yield self.store.update_app_service(app_service)
defer.returnValue(app_service)
@defer.inlineCallbacks
def unregister(self, token):
logger.info("Unregister as_token=%s", token)
yield self.store.unregister_app_service(token)
@defer.inlineCallbacks @defer.inlineCallbacks
def notify_interested_services(self, event): def notify_interested_services(self, event):
"""Notifies (pushes) all application services interested in this event. """Notifies (pushes) all application services interested in this event.
@ -223,6 +189,3 @@ class ApplicationServicesHandler(object):
exists = yield self.query_user_exists(user_id) exists = yield self.query_user_exists(user_id)
defer.returnValue(exists) defer.returnValue(exists)
defer.returnValue(True) defer.returnValue(True)
def _generate_hs_token(self):
return stringutils.random_string(24)

View file

@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -1,48 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
from synapse.http.servlet import RestServlet
from synapse.api.urls import APP_SERVICE_PREFIX
import re
import logging
logger = logging.getLogger(__name__)
def as_path_pattern(path_regex):
"""Creates a regex compiled appservice path with the correct path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + APP_SERVICE_PREFIX + path_regex)
class AppServiceRestServlet(RestServlet):
"""A base Synapse REST Servlet for the application services version 1 API.
"""
def __init__(self, hs):
self.hs = hs
self.handler = hs.get_handlers().appservice_handler

View file

@ -1,99 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains REST servlets to do with registration: /register"""
from twisted.internet import defer
from base import AppServiceRestServlet, as_path_pattern
from synapse.api.errors import CodeMessageException, SynapseError
from synapse.storage.appservice import ApplicationService
import json
import logging
logger = logging.getLogger(__name__)
class RegisterRestServlet(AppServiceRestServlet):
"""Handles AS registration with the home server.
"""
PATTERN = as_path_pattern("/register$")
@defer.inlineCallbacks
def on_POST(self, request):
params = _parse_json(request)
# sanity check required params
try:
as_token = params["as_token"]
as_url = params["url"]
if (not isinstance(as_token, basestring) or
not isinstance(as_url, basestring)):
raise ValueError
except (KeyError, ValueError):
raise SynapseError(
400, "Missed required keys: as_token(str) / url(str)."
)
try:
app_service = ApplicationService(
as_token, as_url, params["namespaces"]
)
except ValueError as e:
raise SynapseError(400, e.message)
app_service = yield self.handler.register(app_service)
hs_token = app_service.hs_token
defer.returnValue((200, {
"hs_token": hs_token
}))
class UnregisterRestServlet(AppServiceRestServlet):
"""Handles AS registration with the home server.
"""
PATTERN = as_path_pattern("/unregister$")
def on_POST(self, request):
params = _parse_json(request)
try:
as_token = params["as_token"]
if not isinstance(as_token, basestring):
raise ValueError
except (KeyError, ValueError):
raise SynapseError(400, "Missing required key: as_token(str)")
yield self.handler.unregister(as_token)
raise CodeMessageException(500, "Not implemented")
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError as e:
logger.warn(e)
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)
UnregisterRestServlet(hs).register(http_server)

View file

@ -79,7 +79,6 @@ class BaseHomeServer(object):
'resource_for_content_repo', 'resource_for_content_repo',
'resource_for_server_key', 'resource_for_server_key',
'resource_for_media_repository', 'resource_for_media_repository',
'resource_for_app_services',
'resource_for_metrics', 'resource_for_metrics',
'event_sources', 'event_sources',
'ratelimiter', 'ratelimiter',

View file

@ -13,155 +13,35 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import urllib
import yaml
from simplejson import JSONDecodeError from simplejson import JSONDecodeError
import simplejson as json import simplejson as json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import StoreError
from synapse.appservice import ApplicationService, AppServiceTransaction from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from synapse.types import UserID
from ._base import SQLBaseStore from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def log_failure(failure):
logger.error("Failed to detect application services: %s", failure.value)
logger.error(failure.getTraceback())
class ApplicationServiceStore(SQLBaseStore): class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs) super(ApplicationServiceStore, self).__init__(hs)
self.hostname = hs.hostname
self.services_cache = [] self.services_cache = []
self.cache_defer = self._populate_appservice_cache() self._populate_appservice_cache(
self.cache_defer.addErrback(log_failure) hs.config.app_service_config_files
@defer.inlineCallbacks
def unregister_app_service(self, token):
"""Unregisters this service.
This removes all AS specific regex and the base URL. The token is the
only thing preserved for future registration attempts.
"""
yield self.cache_defer # make sure the cache is ready
yield self.runInteraction(
"unregister_app_service",
self._unregister_app_service_txn,
token,
)
# update cache TODO: Should this be in the txn?
for service in self.services_cache:
if service.token == token:
service.url = None
service.namespaces = None
service.hs_token = None
def _unregister_app_service_txn(self, txn, token):
# kill the url to prevent pushes
txn.execute(
"UPDATE application_services SET url=NULL WHERE token=?",
(token,)
) )
# cleanup regex
as_id = self._get_as_id_txn(txn, token)
if not as_id:
logger.warning(
"unregister_app_service_txn: Failed to find as_id for token=",
token
)
return False
txn.execute(
"DELETE FROM application_services_regex WHERE as_id=?",
(as_id,)
)
return True
@defer.inlineCallbacks
def update_app_service(self, service):
"""Update an application service, clobbering what was previously there.
Args:
service(ApplicationService): The updated service.
"""
yield self.cache_defer # make sure the cache is ready
# NB: There is no "insert" since we provide no public-facing API to
# allocate new ASes. It relies on the server admin inserting the AS
# token into the database manually.
if not service.token or not service.url:
raise StoreError(400, "Token and url must be specified.")
if not service.hs_token:
raise StoreError(500, "No HS token")
as_id = yield self.runInteraction(
"update_app_service",
self._update_app_service_txn,
service
)
service.id = as_id
# update cache TODO: Should this be in the txn?
for (index, cache_service) in enumerate(self.services_cache):
if service.token == cache_service.token:
self.services_cache[index] = service
logger.info("Updated: %s", service)
return
# new entry
self.services_cache.append(service)
logger.info("Updated(new): %s", service)
def _update_app_service_txn(self, txn, service):
as_id = self._get_as_id_txn(txn, service.token)
if not as_id:
logger.warning(
"update_app_service_txn: Failed to find as_id for token=",
service.token
)
return
txn.execute(
"UPDATE application_services SET url=?, hs_token=?, sender=? "
"WHERE id=?",
(service.url, service.hs_token, service.sender, as_id,)
)
# cleanup regex
txn.execute(
"DELETE FROM application_services_regex WHERE as_id=?",
(as_id,)
)
for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST):
if ns_str in service.namespaces:
for regex_obj in service.namespaces[ns_str]:
txn.execute(
"INSERT INTO application_services_regex("
"as_id, namespace, regex) values(?,?,?)",
(as_id, ns_int, json.dumps(regex_obj))
)
return as_id
def _get_as_id_txn(self, txn, token):
cursor = txn.execute(
"SELECT id FROM application_services WHERE token=?",
(token,)
)
res = cursor.fetchone()
if res:
return res[0]
@defer.inlineCallbacks
def get_app_services(self): def get_app_services(self):
yield self.cache_defer # make sure the cache is ready return defer.succeed(self.services_cache)
defer.returnValue(self.services_cache)
@defer.inlineCallbacks
def get_app_service_by_user_id(self, user_id): def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID. """Retrieve an application service from their user ID.
@ -175,37 +55,23 @@ class ApplicationServiceStore(SQLBaseStore):
Returns: Returns:
synapse.appservice.ApplicationService or None. synapse.appservice.ApplicationService or None.
""" """
yield self.cache_defer # make sure the cache is ready
for service in self.services_cache: for service in self.services_cache:
if service.sender == user_id: if service.sender == user_id:
defer.returnValue(service) return defer.succeed(service)
return return defer.succeed(None)
defer.returnValue(None)
@defer.inlineCallbacks def get_app_service_by_token(self, token):
def get_app_service_by_token(self, token, from_cache=True):
"""Get the application service with the given appservice token. """Get the application service with the given appservice token.
Args: Args:
token (str): The application service token. token (str): The application service token.
from_cache (bool): True to get this service from the cache, False to Returns:
check the database. synapse.appservice.ApplicationService or None.
Raises:
StoreError if there was a problem retrieving this service.
""" """
yield self.cache_defer # make sure the cache is ready
if from_cache:
for service in self.services_cache: for service in self.services_cache:
if service.token == token: if service.token == token:
defer.returnValue(service) return defer.succeed(service)
return return defer.succeed(None)
defer.returnValue(None)
# TODO: The from_cache=False impl
# TODO: This should be JOINed with the application_services_regex table.
def get_app_service_rooms(self, service): def get_app_service_rooms(self, service):
"""Get a list of RoomsForUser for this application service. """Get a list of RoomsForUser for this application service.
@ -336,18 +202,69 @@ class ApplicationServiceStore(SQLBaseStore):
)) ))
return service_list return service_list
@defer.inlineCallbacks def _load_appservice(self, as_info):
def _populate_appservice_cache(self): required_string_fields = [
"""Populates the ApplicationServiceCache from the database.""" "url", "as_token", "hs_token", "sender_localpart"
sql = ("SELECT r.*, a.* FROM application_services AS a LEFT JOIN " ]
"application_services_regex AS r ON a.id = r.as_id") for field in required_string_fields:
if not isinstance(as_info.get(field), basestring):
raise KeyError("Required string field: '%s'", field)
results = yield self._execute_and_decode("appservice_cache", sql) localpart = as_info["sender_localpart"]
services = self._parse_services_dict(results) if urllib.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
user = UserID(localpart, self.hostname)
user_id = user.to_string()
for service in services: # namespace checks
logger.info("Found application service: %s", service) if not isinstance(as_info.get("namespaces"), dict):
self.services_cache.append(service) raise KeyError("Requires 'namespaces' object.")
for ns in ApplicationService.NS_LIST:
# specific namespaces are optional
if ns in as_info["namespaces"]:
# expect a list of dicts with exclusive and regex keys
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
"Expected namespace entry in %s to be an object,"
" but got %s", ns, regex_obj
)
if not isinstance(regex_obj.get("regex"), basestring):
raise ValueError(
"Missing/bad type 'regex' key in %s", regex_obj
)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["as_token"] # the token is the only unique thing here
)
def _populate_appservice_cache(self, config_files):
"""Populates a cache of Application Services from the config files."""
if not isinstance(config_files, list):
logger.warning(
"Expected %s to be a list of AS config files.", config_files
)
return
for config_file in config_files:
try:
with open(config_file, 'r') as f:
appservice = self._load_appservice(yaml.load(f))
logger.info("Loaded application service: %s", appservice)
self.services_cache.append(appservice)
except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e)
class ApplicationServiceTransactionStore(SQLBaseStore): class ApplicationServiceTransactionStore(SQLBaseStore):
@ -365,16 +282,20 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
A Deferred which resolves to a list of ApplicationServices, which A Deferred which resolves to a list of ApplicationServices, which
may be empty. may be empty.
""" """
sql = ( results = yield self._simple_select_list(
"SELECT r.*, a.* FROM application_services_state AS s LEFT JOIN" "application_services_state",
" application_services AS a ON a.id=s.as_id LEFT JOIN" dict(state=state),
" application_services_regex AS r ON r.as_id=a.id WHERE state = ?" ["as_id"]
)
results = yield self._execute_and_decode(
"get_appservices_by_state", sql, state
) )
# NB: This assumes this class is linked with ApplicationServiceStore # NB: This assumes this class is linked with ApplicationServiceStore
defer.returnValue(self._parse_services_dict(results)) as_list = yield self.get_app_services()
services = []
for res in results:
for service in as_list:
if service.id == res["as_id"]:
services.append(service)
defer.returnValue(services)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservice_state(self, service): def get_appservice_state(self, service):

View file

@ -14,14 +14,13 @@
*/ */
CREATE TABLE IF NOT EXISTS application_services_state( CREATE TABLE IF NOT EXISTS application_services_state(
as_id INTEGER PRIMARY KEY, as_id TEXT PRIMARY KEY,
state TEXT, state TEXT,
last_txn TEXT, last_txn TEXT
FOREIGN KEY(as_id) REFERENCES application_services(id)
); );
CREATE TABLE IF NOT EXISTS application_services_txns( CREATE TABLE IF NOT EXISTS application_services_txns(
as_id INTEGER NOT NULL, as_id TEXT NOT NULL,
txn_id INTEGER NOT NULL, txn_id INTEGER NOT NULL,
event_ids TEXT NOT NULL, event_ids TEXT NOT NULL,
UNIQUE(as_id, txn_id) ON CONFLICT ROLLBACK UNIQUE(as_id, txn_id) ON CONFLICT ROLLBACK

View file

@ -22,6 +22,8 @@ from synapse.storage.appservice import (
) )
import json import json
import os
import yaml
from mock import Mock from mock import Mock
from tests.utils import SQLiteMemoryDbPool, MockClock from tests.utils import SQLiteMemoryDbPool, MockClock
@ -30,63 +32,39 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.as_yaml_files = []
db_pool = SQLiteMemoryDbPool() db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare() yield db_pool.prepare()
hs = HomeServer( hs = HomeServer(
"test", db_pool=db_pool, clock=MockClock(), config=Mock() "test", db_pool=db_pool, clock=MockClock(),
config=Mock(
app_service_config_files=self.as_yaml_files
) )
)
self.as_token = "token1" self.as_token = "token1"
db_pool.runQuery( self.as_url = "some_url"
"INSERT INTO application_services(token) VALUES(?)", self._add_appservice(self.as_token, self.as_url, "some_hs_token", "bob")
(self.as_token,) self._add_appservice("token2", "some_url", "some_hs_token", "bob")
) self._add_appservice("token3", "some_url", "some_hs_token", "bob")
db_pool.runQuery(
"INSERT INTO application_services(token) VALUES(?)", ("token2",)
)
db_pool.runQuery(
"INSERT INTO application_services(token) VALUES(?)", ("token3",)
)
# must be done after inserts # must be done after inserts
self.store = ApplicationServiceStore(hs) self.store = ApplicationServiceStore(hs)
@defer.inlineCallbacks def tearDown(self):
def test_update_and_retrieval_of_service(self): # TODO: suboptimal that we need to create files for tests!
url = "https://matrix.org/appservices/foobar" for f in self.as_yaml_files:
hs_token = "hstok" try:
user_regex = [ os.remove(f)
{"regex": "@foobar_.*:matrix.org", "exclusive": True} except:
] pass
alias_regex = [
{"regex": "#foobar_.*:matrix.org", "exclusive": False}
]
room_regex = [
] def _add_appservice(self, as_token, url, hs_token, sender):
service = ApplicationService( as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token,
url=url, hs_token=hs_token, token=self.as_token, namespaces={ sender_localpart=sender, namespaces={})
ApplicationService.NS_USERS: user_regex, # use the token as the filename
ApplicationService.NS_ALIASES: alias_regex, with open(as_token, 'w') as outfile:
ApplicationService.NS_ROOMS: room_regex outfile.write(yaml.dump(as_yaml))
}) self.as_yaml_files.append(as_token)
yield self.store.update_app_service(service)
stored_service = yield self.store.get_app_service_by_token(
self.as_token
)
self.assertEquals(stored_service.token, self.as_token)
self.assertEquals(stored_service.url, url)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ALIASES],
alias_regex
)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ROOMS],
room_regex
)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_USERS],
user_regex
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_retrieve_unknown_service_token(self): def test_retrieve_unknown_service_token(self):
@ -99,7 +77,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self.as_token self.as_token
) )
self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.token, self.as_token)
self.assertEquals(stored_service.url, None) self.assertEquals(stored_service.url, self.as_url)
self.assertEquals( self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ALIASES], stored_service.namespaces[ApplicationService.NS_ALIASES],
[] []
@ -123,42 +101,48 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.as_yaml_files = []
self.db_pool = SQLiteMemoryDbPool() self.db_pool = SQLiteMemoryDbPool()
yield self.db_pool.prepare() yield self.db_pool.prepare()
hs = HomeServer(
"test", db_pool=self.db_pool, clock=MockClock(), config=Mock()
)
self.as_list = [ self.as_list = [
{ {
"token": "token1", "token": "token1",
"url": "https://matrix-as.org", "url": "https://matrix-as.org",
"id": 3 "id": "token1"
}, },
{ {
"token": "alpha_tok", "token": "alpha_tok",
"url": "https://alpha.com", "url": "https://alpha.com",
"id": 5 "id": "alpha_tok"
}, },
{ {
"token": "beta_tok", "token": "beta_tok",
"url": "https://beta.com", "url": "https://beta.com",
"id": 6 "id": "beta_tok"
}, },
{ {
"token": "delta_tok", "token": "delta_tok",
"url": "https://delta.com", "url": "https://delta.com",
"id": 7 "id": "delta_tok"
}, },
] ]
for s in self.as_list: for s in self.as_list:
yield self._add_service(s["id"], s["url"], s["token"]) yield self._add_service(s["url"], s["token"])
hs = HomeServer(
"test", db_pool=self.db_pool, clock=MockClock(), config=Mock(
app_service_config_files=self.as_yaml_files
)
)
self.store = TestTransactionStore(hs) self.store = TestTransactionStore(hs)
def _add_service(self, as_id, url, token): def _add_service(self, url, as_token):
return self.db_pool.runQuery( as_yaml = dict(url=url, as_token=as_token, hs_token="something",
"INSERT INTO application_services(id, url, token) VALUES(?,?,?)", sender_localpart="a_sender", namespaces={})
(as_id, url, token) # use the token as the filename
) with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
def _set_state(self, id, state, txn=None): def _set_state(self, id, state, txn=None):
return self.db_pool.runQuery( return self.db_pool.runQuery(
@ -410,8 +394,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
ApplicationServiceState.DOWN ApplicationServiceState.DOWN
) )
self.assertEquals(2, len(services)) self.assertEquals(2, len(services))
self.assertEquals(self.as_list[2]["id"], services[0].id) self.assertEquals(
self.assertEquals(self.as_list[0]["id"], services[1].id) set([self.as_list[2]["id"], self.as_list[0]["id"]]),
set([services[0].id, services[1].id])
)
# required for ApplicationServiceTransactionStoreTestCase tests # required for ApplicationServiceTransactionStoreTestCase tests