Merge pull request #1164 from pik/error-codes

Clarify Error codes for GET /filter/
This commit is contained in:
Erik Johnston 2016-10-19 14:26:17 +01:00 committed by GitHub
commit 78c083f159
4 changed files with 98 additions and 52 deletions

View file

@ -15,7 +15,7 @@
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from synapse.api.errors import AuthError, SynapseError, StoreError, Codes
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
@ -45,7 +45,7 @@ class GetFilterRestServlet(RestServlet):
raise AuthError(403, "Cannot get filters for other users")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only get filters for local users")
raise AuthError(403, "Can only get filters for local users")
try:
filter_id = int(filter_id)
@ -59,8 +59,8 @@ class GetFilterRestServlet(RestServlet):
)
defer.returnValue((200, filter.get_filter_json()))
except KeyError:
raise SynapseError(400, "No such filter")
except (KeyError, StoreError):
raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND)
class CreateFilterRestServlet(RestServlet):
@ -74,6 +74,7 @@ class CreateFilterRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request)
@ -81,10 +82,9 @@ class CreateFilterRestServlet(RestServlet):
raise AuthError(403, "Cannot create filters for other users")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only create filters for local users")
raise AuthError(403, "Can only create filters for local users")
content = parse_json_object_from_request(request)
filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart,
user_filter=content,

View file

@ -85,7 +85,6 @@ class LoggingTransaction(object):
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
sql = self.database_engine.convert_param_style(sql)
if args:
try:
sql_logger.debug(

View file

@ -15,78 +15,125 @@
from twisted.internet import defer
from . import V2AlphaRestTestCase
from tests import unittest
from synapse.rest.client.v2_alpha import filter
from synapse.api.errors import StoreError
from synapse.api.errors import Codes
import synapse.types
from synapse.types import UserID
from ....utils import MockHttpResource, setup_test_homeserver
PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(V2AlphaRestTestCase):
class FilterTestCase(unittest.TestCase):
USER_ID = "@apple:test"
EXAMPLE_FILTER = {"type": ["m.*"]}
EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}'
TO_REGISTER = [filter]
def make_datastore_mock(self):
datastore = super(FilterTestCase, self).make_datastore_mock()
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self._user_filters = {}
self.hs = yield setup_test_homeserver(
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
)
def add_user_filter(user_localpart, definition):
filters = self._user_filters.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
return defer.succeed(filter_id)
datastore.add_user_filter = add_user_filter
self.auth = self.hs.get_auth()
def get_user_filter(user_localpart, filter_id):
if user_localpart not in self._user_filters:
raise StoreError(404, "No user")
filters = self._user_filters[user_localpart]
if filter_id >= len(filters):
raise StoreError(404, "No filter")
return defer.succeed(filters[filter_id])
datastore.get_user_filter = get_user_filter
def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.USER_ID),
"token_id": 1,
"is_guest": False,
}
return datastore
def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester(
UserID.from_string(self.USER_ID), 1, False, None)
self.auth.get_user_by_access_token = get_user_by_access_token
self.auth.get_user_by_req = get_user_by_req
self.store = self.hs.get_datastore()
self.filtering = self.hs.get_filtering()
for r in self.TO_REGISTER:
r.register_servlets(self.hs, self.mock_resource)
@defer.inlineCallbacks
def test_add_filter(self):
(code, response) = yield self.mock_resource.trigger(
"POST", "/user/%s/filter" % (self.USER_ID), '{"type": ["m.*"]}'
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
)
self.assertEquals(200, code)
self.assertEquals({"filter_id": "0"}, response)
filter = yield self.store.get_user_filter(
user_localpart='apple',
filter_id=0,
)
self.assertEquals(filter, self.EXAMPLE_FILTER)
self.assertIn("apple", self._user_filters)
self.assertEquals(len(self._user_filters["apple"]), 1)
self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0])
@defer.inlineCallbacks
def test_add_filter_for_other_user(self):
(code, response) = yield self.mock_resource.trigger(
"POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON
)
self.assertEquals(403, code)
self.assertEquals(response['errcode'], Codes.FORBIDDEN)
@defer.inlineCallbacks
def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
(code, response) = yield self.mock_resource.trigger(
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
)
self.hs.is_mine = _is_mine
self.assertEquals(403, code)
self.assertEquals(response['errcode'], Codes.FORBIDDEN)
@defer.inlineCallbacks
def test_get_filter(self):
self._user_filters["apple"] = [
{"type": ["m.*"]}
]
filter_id = yield self.filtering.add_user_filter(
user_localpart='apple',
user_filter=self.EXAMPLE_FILTER
)
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/0" % (self.USER_ID)
"/user/%s/filter/%s" % (self.USER_ID, filter_id)
)
self.assertEquals(200, code)
self.assertEquals({"type": ["m.*"]}, response)
self.assertEquals(self.EXAMPLE_FILTER, response)
@defer.inlineCallbacks
def test_get_filter_non_existant(self):
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/12382148321" % (self.USER_ID)
)
self.assertEquals(400, code)
self.assertEquals(response['errcode'], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
# in errors.py
@defer.inlineCallbacks
def test_get_filter_invalid_id(self):
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/foobar" % (self.USER_ID)
)
self.assertEquals(400, code)
# No ID also returns an invalid_id error
@defer.inlineCallbacks
def test_get_filter_no_id(self):
self._user_filters["apple"] = [
{"type": ["m.*"]}
]
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/2" % (self.USER_ID)
"/user/%s/filter/" % (self.USER_ID)
)
self.assertEquals(404, code)
@defer.inlineCallbacks
def test_get_filter_no_user(self):
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/0" % (self.USER_ID)
)
self.assertEquals(404, code)
self.assertEquals(400, code)

View file

@ -189,7 +189,7 @@ class MockHttpResource(HttpServer):
)
defer.returnValue((code, response))
except CodeMessageException as e:
defer.returnValue((e.code, cs_error(e.msg)))
defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
raise KeyError("No event can handle %s" % path)