Merge branch 'develop' into rav/flatten_sync_response

This commit is contained in:
Paul "LeoNerd" Evans 2015-11-19 17:21:03 +00:00
commit dd11bf8a79
17 changed files with 163 additions and 87 deletions

View file

@ -1,3 +1,8 @@
Changes in synapse v0.11.0-r2 (2015-11-19)
==========================================
* Fix bug in database port script (PR #387)
Changes in synapse v0.11.0-r1 (2015-11-18) Changes in synapse v0.11.0-r1 (2015-11-18)
========================================== ==========================================

View file

@ -68,6 +68,7 @@ APPEND_ONLY_TABLES = [
"state_groups_state", "state_groups_state",
"event_to_state_groups", "event_to_state_groups",
"rejections", "rejections",
"event_search",
] ]
@ -229,6 +230,38 @@ class Porter(object):
if rows: if rows:
next_chunk = rows[-1][0] + 1 next_chunk = rows[-1][0] + 1
if table == "event_search":
# We have to treat event_search differently since it has a
# different structure in the two different databases.
def insert(txn):
sql = (
"INSERT INTO event_search (event_id, room_id, key, sender, vector)"
" VALUES (?,?,?,?,to_tsvector('english', ?))"
)
rows_dict = [
dict(zip(headers, row))
for row in rows
]
txn.executemany(sql, [
(
row["event_id"],
row["room_id"],
row["key"],
row["sender"],
row["value"],
)
for row in rows_dict
])
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk},
)
else:
self._convert_rows(table, headers, rows) self._convert_rows(table, headers, rows)
def insert(txn): def insert(txn):

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.11.0-r1" __version__ = "0.11.0-r2"

View file

@ -587,10 +587,7 @@ class Auth(object):
def _get_user_from_macaroon(self, macaroon_str): def _get_user_from_macaroon(self, macaroon_str):
try: try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self.validate_macaroon( self.validate_macaroon(macaroon, "access", False)
macaroon, "access",
[lambda c: c.startswith("time < ")]
)
user_prefix = "user_id = " user_prefix = "user_id = "
user = None user = None
@ -638,22 +635,34 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
def validate_macaroon(self, macaroon, type_string, additional_validation_functions): def validate_macaroon(self, macaroon, type_string, verify_expiry):
"""
validate that a Macaroon is understood by and was signed by this server.
Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token this is (e.g. "access", "refresh")
verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet.
"""
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1") v.satisfy_exact("gen = 1")
v.satisfy_exact("type = " + type_string) v.satisfy_exact("type = " + type_string)
v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_exact("guest = true") v.satisfy_exact("guest = true")
if verify_expiry:
v.satisfy_general(self._verify_expiry)
else:
v.satisfy_general(lambda c: c.startswith("time < "))
for validation_function in additional_validation_functions:
v.satisfy_general(validation_function)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
v.satisfy_general(self._verify_recognizes_caveats) v.satisfy_general(self._verify_recognizes_caveats)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
def verify_expiry(self, caveat): def _verify_expiry(self, caveat):
prefix = "time < " prefix = "time < "
if not caveat.startswith(prefix): if not caveat.startswith(prefix):
return False return False

View file

@ -54,7 +54,7 @@ class Filtering(object):
] ]
room_level_definitions = [ room_level_definitions = [
"state", "timeline", "ephemeral", "private_user_data" "state", "timeline", "ephemeral", "account_data"
] ]
for key in top_level_definitions: for key in top_level_definitions:
@ -131,8 +131,8 @@ class FilterCollection(object):
self.filter_json.get("room", {}).get("ephemeral", {}) self.filter_json.get("room", {}).get("ephemeral", {})
) )
self.room_private_user_data = Filter( self.room_account_data = Filter(
self.filter_json.get("room", {}).get("private_user_data", {}) self.filter_json.get("room", {}).get("account_data", {})
) )
self.presence_filter = Filter( self.presence_filter = Filter(
@ -160,8 +160,8 @@ class FilterCollection(object):
def filter_room_ephemeral(self, events): def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(events) return self.room_ephemeral_filter.filter(events)
def filter_room_private_user_data(self, events): def filter_room_account_data(self, events):
return self.room_private_user_data.filter(events) return self.room_account_data.filter(events)
class Filter(object): class Filter(object):

View file

@ -25,18 +25,29 @@ class ConfigError(Exception):
pass pass
# We split these messages out to allow packages to override with package
# specific instructions.
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\
Please opt in or out of reporting anonymized homeserver usage statistics, by
setting the `report_stats` key in your config file to either True or False.
"""
MISSING_REPORT_STATS_SPIEL = """\
We would really appreciate it if you could help our project out by reporting
anonymized usage statistics from your homeserver. Only very basic aggregate
data (e.g. number of users) will be reported, but it helps us to track the
growth of the Matrix community, and helps us to make Matrix a success, as well
as to convince other networks that they should peer with us.
Thank you.
"""
MISSING_SERVER_NAME = """\
Missing mandatory `server_name` config option.
"""
class Config(object): class Config(object):
stats_reporting_begging_spiel = (
"We would really appreciate it if you could help our project out by"
" reporting anonymized usage statistics from your homeserver. Only very"
" basic aggregate data (e.g. number of users) will be reported, but it"
" helps us to track the growth of the Matrix community, and helps us to"
" make Matrix a success, as well as to convince other networks that they"
" should peer with us."
"\nThank you."
)
@staticmethod @staticmethod
def parse_size(value): def parse_size(value):
if isinstance(value, int) or isinstance(value, long): if isinstance(value, int) or isinstance(value, long):
@ -215,7 +226,7 @@ class Config(object):
if config_args.report_stats is None: if config_args.report_stats is None:
config_parser.error( config_parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n" + "Please specify either --report-stats=yes or --report-stats=no\n\n" +
cls.stats_reporting_begging_spiel MISSING_REPORT_STATS_SPIEL
) )
if not config_files: if not config_files:
config_parser.error( config_parser.error(
@ -290,6 +301,10 @@ class Config(object):
yaml_config = cls.read_config_file(config_file) yaml_config = cls.read_config_file(config_file)
specified_config.update(yaml_config) specified_config.update(yaml_config)
if "server_name" not in specified_config:
sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n")
sys.exit(1)
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
_, config = obj.generate_config( _, config = obj.generate_config(
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
@ -299,11 +314,8 @@ class Config(object):
config.update(specified_config) config.update(specified_config)
if "report_stats" not in config: if "report_stats" not in config:
sys.stderr.write( sys.stderr.write(
"Please opt in or out of reporting anonymized homeserver usage " "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
"statistics, by setting the report_stats key in your config file " MISSING_REPORT_STATS_SPIEL + "\n")
" ( " + config_path + " ) " +
"to either True or False.\n\n" +
Config.stats_reporting_begging_spiel + "\n")
sys.exit(1) sys.exit(1)
if generate_keys: if generate_keys:

View file

@ -16,19 +16,19 @@
from twisted.internet import defer from twisted.internet import defer
class PrivateUserDataEventSource(object): class AccountDataEventSource(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
def get_current_key(self, direction='f'): def get_current_key(self, direction='f'):
return self.store.get_max_private_user_data_stream_id() return self.store.get_max_account_data_stream_id()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events(self, user, from_key, **kwargs): def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string() user_id = user.to_string()
last_stream_id = from_key last_stream_id = from_key
current_stream_id = yield self.store.get_max_private_user_data_stream_id() current_stream_id = yield self.store.get_max_account_data_stream_id()
tags = yield self.store.get_updated_tags(user_id, last_stream_id) tags = yield self.store.get_updated_tags(user_id, last_stream_id)
results = [] results = []

View file

@ -407,7 +407,7 @@ class AuthHandler(BaseHandler):
try: try:
macaroon = pymacaroons.Macaroon.deserialize(login_token) macaroon = pymacaroons.Macaroon.deserialize(login_token)
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
auth_api.validate_macaroon(macaroon, "login", [auth_api.verify_expiry]) auth_api.validate_macaroon(macaroon, "login", True)
return self._get_user_from_macaroon(macaroon) return self._get_user_from_macaroon(macaroon)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)

View file

@ -436,14 +436,14 @@ class MessageHandler(BaseHandler):
for c in current_state.values() for c in current_state.values()
] ]
private_user_data = [] account_data = []
tags = tags_by_room.get(event.room_id) tags = tags_by_room.get(event.room_id)
if tags: if tags:
private_user_data.append({ account_data.append({
"type": "m.tag", "type": "m.tag",
"content": {"tags": tags}, "content": {"tags": tags},
}) })
d["private_user_data"] = private_user_data d["account_data"] = account_data
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
@ -498,14 +498,14 @@ class MessageHandler(BaseHandler):
user_id, room_id, pagin_config, membership, member_event_id, is_guest user_id, room_id, pagin_config, membership, member_event_id, is_guest
) )
private_user_data = [] account_data = []
tags = yield self.store.get_tags_for_room(user_id, room_id) tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags: if tags:
private_user_data.append({ account_data.append({
"type": "m.tag", "type": "m.tag",
"content": {"tags": tags}, "content": {"tags": tags},
}) })
result["private_user_data"] = private_user_data result["account_data"] = account_data
defer.returnValue(result) defer.returnValue(result)

View file

@ -51,7 +51,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"timeline", # TimelineBatch "timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent] "state", # dict[(str, str), FrozenEvent]
"ephemeral", "ephemeral",
"private_user_data", "account_data",
])): ])):
__slots__ = [] __slots__ = []
@ -63,7 +63,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
self.timeline self.timeline
or self.state or self.state
or self.ephemeral or self.ephemeral
or self.private_user_data or self.account_data
) )
@ -71,7 +71,7 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id", # str "room_id", # str
"timeline", # TimelineBatch "timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent] "state", # dict[(str, str), FrozenEvent]
"private_user_data", "account_data",
])): ])):
__slots__ = [] __slots__ = []
@ -82,7 +82,7 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
return bool( return bool(
self.timeline self.timeline
or self.state or self.state
or self.private_user_data or self.account_data
) )
@ -261,20 +261,20 @@ class SyncHandler(BaseHandler):
timeline=batch, timeline=batch,
state=current_state, state=current_state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room
), ),
)) ))
def private_user_data_for_room(self, room_id, tags_by_room): def account_data_for_room(self, room_id, tags_by_room):
private_user_data = [] account_data = []
tags = tags_by_room.get(room_id) tags = tags_by_room.get(room_id)
if tags is not None: if tags is not None:
private_user_data.append({ account_data.append({
"type": "m.tag", "type": "m.tag",
"content": {"tags": tags}, "content": {"tags": tags},
}) })
return private_user_data return account_data
@defer.inlineCallbacks @defer.inlineCallbacks
def ephemeral_by_room(self, sync_config, now_token, since_token=None): def ephemeral_by_room(self, sync_config, now_token, since_token=None):
@ -357,7 +357,7 @@ class SyncHandler(BaseHandler):
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=leave_state, state=leave_state,
private_user_data=self.private_user_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room
), ),
)) ))
@ -412,7 +412,7 @@ class SyncHandler(BaseHandler):
tags_by_room = yield self.store.get_updated_tags( tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(), sync_config.user.to_string(),
since_token.private_user_data_key, since_token.account_data_key,
) )
joined = [] joined = []
@ -468,7 +468,7 @@ class SyncHandler(BaseHandler):
), ),
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room
), ),
) )
@ -605,7 +605,7 @@ class SyncHandler(BaseHandler):
timeline=batch, timeline=batch,
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room
), ),
) )
@ -653,7 +653,7 @@ class SyncHandler(BaseHandler):
room_id=leave_event.room_id, room_id=leave_event.room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
private_user_data=self.private_user_data_for_room( account_data=self.account_data_for_room(
leave_event.room_id, tags_by_room leave_event.room_id, tags_by_room
), ),
) )

View file

@ -274,8 +274,8 @@ class SyncRestServlet(RestServlet):
serialized_state = [serialize(e) for e in state_events] serialized_state = [serialize(e) for e in state_events]
serialized_timeline = [serialize(e) for e in timeline_events] serialized_timeline = [serialize(e) for e in timeline_events]
private_user_data = filter.filter_room_private_user_data( account_data = filter.filter_room_account_data(
room.private_user_data room.account_data
) )
result = { result = {
@ -285,7 +285,7 @@ class SyncRestServlet(RestServlet):
"limited": room.timeline.limited, "limited": room.timeline.limited,
}, },
"state": {"events": serialized_state}, "state": {"events": serialized_state},
"private_user_data": {"events": private_user_data}, "account_data": {"events": account_data},
} }
if joined: if joined:

View file

@ -81,7 +81,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
yield self.notifier.on_new_event( yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -95,7 +95,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event( yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 25 SCHEMA_VERSION = 26
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -0,0 +1,17 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id;

View file

@ -28,17 +28,17 @@ class TagsStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
super(TagsStore, self).__init__(hs) super(TagsStore, self).__init__(hs)
self._private_user_data_id_gen = StreamIdGenerator( self._account_data_id_gen = StreamIdGenerator(
"private_user_data_max_stream_id", "stream_id" "account_data_max_stream_id", "stream_id"
) )
def get_max_private_user_data_stream_id(self): def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream """Get the current max stream id for the private user data stream
Returns: Returns:
A deferred int. A deferred int.
""" """
return self._private_user_data_id_gen.get_max_token(self) return self._account_data_id_gen.get_max_token(self)
@cached() @cached()
def get_tags_for_user(self, user_id): def get_tags_for_user(self, user_id):
@ -144,12 +144,12 @@ class TagsStore(SQLBaseStore):
) )
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id: with (yield self._account_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id) yield self.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self) result = yield self._account_data_id_gen.get_max_token(self)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -166,12 +166,12 @@ class TagsStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id, tag)) txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id: with (yield self._account_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id) yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self) result = yield self._account_data_id_gen.get_max_token(self)
defer.returnValue(result) defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id): def _update_revision_txn(self, txn, user_id, room_id, next_id):
@ -185,7 +185,7 @@ class TagsStore(SQLBaseStore):
""" """
update_max_id_sql = ( update_max_id_sql = (
"UPDATE private_user_data_max_stream_id" "UPDATE account_data_max_stream_id"
" SET stream_id = ?" " SET stream_id = ?"
" WHERE stream_id < ?" " WHERE stream_id < ?"
) )

View file

@ -21,7 +21,7 @@ from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.private_user_data import PrivateUserDataEventSource from synapse.handlers.account_data import AccountDataEventSource
class EventSources(object): class EventSources(object):
@ -30,7 +30,7 @@ class EventSources(object):
"presence": PresenceEventSource, "presence": PresenceEventSource,
"typing": TypingNotificationEventSource, "typing": TypingNotificationEventSource,
"receipt": ReceiptEventSource, "receipt": ReceiptEventSource,
"private_user_data": PrivateUserDataEventSource, "account_data": AccountDataEventSource,
} }
def __init__(self, hs): def __init__(self, hs):
@ -54,8 +54,8 @@ class EventSources(object):
receipt_key=( receipt_key=(
yield self.sources["receipt"].get_current_key() yield self.sources["receipt"].get_current_key()
), ),
private_user_data_key=( account_data_key=(
yield self.sources["private_user_data"].get_current_key() yield self.sources["account_data"].get_current_key()
), ),
) )
defer.returnValue(token) defer.returnValue(token)

View file

@ -103,7 +103,7 @@ class StreamToken(
"presence_key", "presence_key",
"typing_key", "typing_key",
"receipt_key", "receipt_key",
"private_user_data_key", "account_data_key",
)) ))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -138,7 +138,7 @@ class StreamToken(
or (int(other.presence_key) < int(self.presence_key)) or (int(other.presence_key) < int(self.presence_key))
or (int(other.typing_key) < int(self.typing_key)) or (int(other.typing_key) < int(self.typing_key))
or (int(other.receipt_key) < int(self.receipt_key)) or (int(other.receipt_key) < int(self.receipt_key))
or (int(other.private_user_data_key) < int(self.private_user_data_key)) or (int(other.account_data_key) < int(self.account_data_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):