0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-18 10:38:21 +02:00

Have the Filtering API return Deferreds, so we can do the Datastore implementation nicely

This commit is contained in:
Paul "LeoNerd" Evans 2015-01-27 16:17:56 +00:00
parent b1503112ce
commit 059651efa1
3 changed files with 22 additions and 7 deletions

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
# TODO(paul)
_filters_for_user = {}
@ -24,18 +26,28 @@ class Filtering(object):
super(Filtering, self).__init__()
self.hs = hs
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
filters = _filters_for_user.get(user_localpart, None)
if not filters or filter_id >= len(filters):
raise KeyError()
return filters[filter_id]
# trivial yield to make it a generator so d.iC works
yield
defer.returnValue(filters[filter_id])
@defer.inlineCallbacks
def add_user_filter(self, user_localpart, definition):
filters = _filters_for_user.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
return filter_id
# trivial yield, see above
yield
defer.returnValue(filter_id)
# TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for
# them however

View file

@ -54,10 +54,12 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id")
try:
defer.returnValue((200, self.filtering.get_user_filter(
filter = yield self.filtering.get_user_filter(
user_localpart=target_user.localpart,
filter_id=filter_id,
)))
)
defer.returnValue((200, filter))
except KeyError:
raise SynapseError(400, "No such filter")
@ -89,7 +91,7 @@ class CreateFilterRestServlet(RestServlet):
except:
raise SynapseError(400, "Invalid filter definition")
filter_id = self.filtering.add_user_filter(
filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart,
definition=content,
)

View file

@ -53,14 +53,15 @@ class FilteringTestCase(unittest.TestCase):
self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def test_filter(self):
filter_id = self.filtering.add_user_filter(
filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart,
definition={"type": ["m.*"]},
)
self.assertEquals(filter_id, 0)
filter = self.filtering.get_user_filter(
filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart,
filter_id=filter_id,
)