Initial stab at real SQL storage implementation of user filter definitions

This commit is contained in:
Paul "LeoNerd" Evans 2015-01-27 18:46:03 +00:00
parent 0c14a699bb
commit 06cc147012
4 changed files with 78 additions and 15 deletions

View file

@ -61,6 +61,7 @@ SCHEMAS = [
"event_edges", "event_edges",
"event_signatures", "event_signatures",
"media_repository", "media_repository",
"filtering",
] ]

View file

@ -17,6 +17,8 @@ from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
import json
# TODO(paul) # TODO(paul)
_filters_for_user = {} _filters_for_user = {}
@ -25,22 +27,41 @@ _filters_for_user = {}
class FilteringStore(SQLBaseStore): class FilteringStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id): def get_user_filter(self, user_localpart, filter_id):
filters = _filters_for_user.get(user_localpart, None) def_json = yield self._simple_select_one_onecol(
table="user_filters",
keyvalues={
"user_id": user_localpart,
"filter_id": filter_id,
},
retcol="definition",
allow_none=False,
)
if not filters or filter_id >= len(filters): defer.returnValue(json.loads(def_json))
raise KeyError()
# trivial yield to make it a generator so d.iC works
yield
defer.returnValue(filters[filter_id])
@defer.inlineCallbacks
def add_user_filter(self, user_localpart, definition): def add_user_filter(self, user_localpart, definition):
filters = _filters_for_user.setdefault(user_localpart, []) def_json = json.dumps(definition)
filter_id = len(filters) # Need an atomic transaction to SELECT the maximal ID so far then
filters.append(definition) # INSERT a new one
def _do_txn(txn):
sql = (
"SELECT MAX(filter_id) FROM user_filters "
"WHERE user_id = ?"
)
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
if max_id is None:
filter_id = 0
else:
filter_id = max_id + 1
# trivial yield, see above sql = (
yield "INSERT INTO user_filters (user_id, filter_id, definition)"
defer.returnValue(filter_id) "VALUES(?, ?, ?)"
)
txn.execute(sql, (user_localpart, filter_id, def_json))
return filter_id
return self.runInteraction("add_user_filter", _do_txn)

View file

@ -0,0 +1,24 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS user_filters(
user_id TEXT,
filter_id INTEGER,
definition TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
user_id, filter_id
);

View file

@ -53,16 +53,33 @@ class FilteringTestCase(unittest.TestCase):
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
self.datastore = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter(self): def test_add_filter(self):
filter_id = yield self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
definition={"type": ["m.*"]}, definition={"type": ["m.*"]},
) )
self.assertEquals(filter_id, 0) self.assertEquals(filter_id, 0)
self.assertEquals({"type": ["m.*"]},
(yield self.datastore.get_user_filter(
user_localpart=user_localpart,
filter_id=0,
))
)
@defer.inlineCallbacks
def test_get_filter(self):
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
definition={"type": ["m.*"]},
)
filter = yield self.filtering.get_user_filter( filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
filter_id=filter_id, filter_id=filter_id,
) )
self.assertEquals(filter, {"type": ["m.*"]}) self.assertEquals(filter, {"type": ["m.*"]})