mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 15:33:50 +01:00
Merge branch 'develop' into dbkr/email_notifs_on_pusher
This commit is contained in:
commit
1f71f386f6
19 changed files with 511 additions and 27 deletions
|
@ -612,7 +612,8 @@ class Auth(object):
|
||||||
def get_user_from_macaroon(self, macaroon_str):
|
def get_user_from_macaroon(self, macaroon_str):
|
||||||
try:
|
try:
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||||
self.validate_macaroon(macaroon, "access", False)
|
|
||||||
|
self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
|
||||||
|
|
||||||
user_prefix = "user_id = "
|
user_prefix = "user_id = "
|
||||||
user = None
|
user = None
|
||||||
|
|
|
@ -57,6 +57,8 @@ class KeyConfig(Config):
|
||||||
seed = self.signing_key[0].seed
|
seed = self.signing_key[0].seed
|
||||||
self.macaroon_secret_key = hashlib.sha256(seed)
|
self.macaroon_secret_key = hashlib.sha256(seed)
|
||||||
|
|
||||||
|
self.expire_access_token = config.get("expire_access_token", False)
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name, is_generating_file=False,
|
def default_config(self, config_dir_path, server_name, is_generating_file=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
base_key_name = os.path.join(config_dir_path, server_name)
|
base_key_name = os.path.join(config_dir_path, server_name)
|
||||||
|
@ -69,6 +71,9 @@ class KeyConfig(Config):
|
||||||
return """\
|
return """\
|
||||||
macaroon_secret_key: "%(macaroon_secret_key)s"
|
macaroon_secret_key: "%(macaroon_secret_key)s"
|
||||||
|
|
||||||
|
# Used to enable access token expiration.
|
||||||
|
expire_access_token: False
|
||||||
|
|
||||||
## Signing Keys ##
|
## Signing Keys ##
|
||||||
|
|
||||||
# Path to the signing key to sign messages with
|
# Path to the signing key to sign messages with
|
||||||
|
|
|
@ -32,6 +32,7 @@ class RegistrationConfig(Config):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||||
|
self.user_creation_max_duration = int(config["user_creation_max_duration"])
|
||||||
|
|
||||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||||
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
||||||
|
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
|
||||||
# secret, even if registration is otherwise disabled.
|
# secret, even if registration is otherwise disabled.
|
||||||
registration_shared_secret: "%(registration_shared_secret)s"
|
registration_shared_secret: "%(registration_shared_secret)s"
|
||||||
|
|
||||||
|
# Sets the expiry for the short term user creation in
|
||||||
|
# milliseconds. For instance the bellow duration is two weeks
|
||||||
|
# in milliseconds.
|
||||||
|
user_creation_max_duration: 1209600000
|
||||||
|
|
||||||
# Set the number of bcrypt rounds used to generate password hash.
|
# Set the number of bcrypt rounds used to generate password hash.
|
||||||
# Larger numbers increase the work factor needed to generate the hash.
|
# Larger numbers increase the work factor needed to generate the hash.
|
||||||
# The default number of rounds is 12.
|
# The default number of rounds is 12.
|
||||||
|
|
|
@ -521,11 +521,11 @@ class AuthHandler(BaseHandler):
|
||||||
))
|
))
|
||||||
return m.serialize()
|
return m.serialize()
|
||||||
|
|
||||||
def generate_short_term_login_token(self, user_id):
|
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = login")
|
macaroon.add_first_party_caveat("type = login")
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
expiry = now + (2 * 60 * 1000)
|
expiry = now + duration_in_ms
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
|
|
|
@ -358,6 +358,59 @@ class RegistrationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_or_create_user(self, localpart, displayname, duration_seconds):
|
||||||
|
"""Creates a new user or returns an access token for an existing one
|
||||||
|
|
||||||
|
Args:
|
||||||
|
localpart : The local part of the user ID to register. If None,
|
||||||
|
one will be randomly generated.
|
||||||
|
Returns:
|
||||||
|
A tuple of (user_id, access_token).
|
||||||
|
Raises:
|
||||||
|
RegistrationError if there was a problem registering.
|
||||||
|
"""
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
if localpart is None:
|
||||||
|
raise SynapseError(400, "Request must include user id")
|
||||||
|
|
||||||
|
need_register = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.check_username(localpart)
|
||||||
|
except SynapseError as e:
|
||||||
|
if e.errcode == Codes.USER_IN_USE:
|
||||||
|
need_register = False
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
user = UserID(localpart, self.hs.hostname)
|
||||||
|
user_id = user.to_string()
|
||||||
|
auth_handler = self.hs.get_handlers().auth_handler
|
||||||
|
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
|
||||||
|
|
||||||
|
if need_register:
|
||||||
|
yield self.store.register(
|
||||||
|
user_id=user_id,
|
||||||
|
token=token,
|
||||||
|
password_hash=None
|
||||||
|
)
|
||||||
|
|
||||||
|
yield registered_user(self.distributor, user)
|
||||||
|
else:
|
||||||
|
yield self.store.flush_user(user_id=user_id)
|
||||||
|
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
||||||
|
|
||||||
|
if displayname is not None:
|
||||||
|
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||||
|
profile_handler = self.hs.get_handlers().profile_handler
|
||||||
|
yield profile_handler.set_displayname(
|
||||||
|
user, user, displayname
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((user_id, token))
|
||||||
|
|
||||||
def auth_handler(self):
|
def auth_handler(self):
|
||||||
return self.hs.get_handlers().auth_handler
|
return self.hs.get_handlers().auth_handler
|
||||||
|
|
||||||
|
|
|
@ -164,8 +164,8 @@ class ReplicationResource(Resource):
|
||||||
"Replicating %d rows of %s from %s -> %s",
|
"Replicating %d rows of %s from %s -> %s",
|
||||||
len(stream_content["rows"]),
|
len(stream_content["rows"]),
|
||||||
stream_name,
|
stream_name,
|
||||||
stream_content["position"],
|
|
||||||
request_streams.get(stream_name),
|
request_streams.get(stream_name),
|
||||||
|
stream_content["position"],
|
||||||
)
|
)
|
||||||
|
|
||||||
request.write(json.dumps(result, ensure_ascii=False))
|
request.write(json.dumps(result, ensure_ascii=False))
|
||||||
|
|
61
synapse/replication/slave/storage/account_data.py
Normal file
61
synapse/replication/slave/storage/account_data.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import BaseSlavedStore
|
||||||
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
from synapse.storage.account_data import AccountDataStore
|
||||||
|
|
||||||
|
|
||||||
|
class SlavedAccountDataStore(BaseSlavedStore):
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
|
||||||
|
self._account_data_id_gen = SlavedIdTracker(
|
||||||
|
db_conn, "account_data_max_stream_id", "stream_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
get_global_account_data_by_type_for_users = (
|
||||||
|
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
|
||||||
|
)
|
||||||
|
|
||||||
|
get_global_account_data_by_type_for_user = (
|
||||||
|
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def stream_positions(self):
|
||||||
|
result = super(SlavedAccountDataStore, self).stream_positions()
|
||||||
|
position = self._account_data_id_gen.get_current_token()
|
||||||
|
result["user_account_data"] = position
|
||||||
|
result["room_account_data"] = position
|
||||||
|
result["tag_account_data"] = position
|
||||||
|
return result
|
||||||
|
|
||||||
|
def process_replication(self, result):
|
||||||
|
stream = result.get("user_account_data")
|
||||||
|
if stream:
|
||||||
|
self._account_data_id_gen.advance(int(stream["position"]))
|
||||||
|
for row in stream["rows"]:
|
||||||
|
user_id, data_type = row[1:3]
|
||||||
|
self.get_global_account_data_by_type_for_user.invalidate(
|
||||||
|
(data_type, user_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
stream = result.get("room_account_data")
|
||||||
|
if stream:
|
||||||
|
self._account_data_id_gen.advance(int(stream["position"]))
|
||||||
|
|
||||||
|
stream = result.get("tag_account_data")
|
||||||
|
if stream:
|
||||||
|
self._account_data_id_gen.advance(int(stream["position"]))
|
|
@ -165,12 +165,14 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
|
|
||||||
stream = result.get("forward_ex_outliers")
|
stream = result.get("forward_ex_outliers")
|
||||||
if stream:
|
if stream:
|
||||||
|
self._stream_id_gen.advance(stream["position"])
|
||||||
for row in stream["rows"]:
|
for row in stream["rows"]:
|
||||||
event_id = row[1]
|
event_id = row[1]
|
||||||
self._invalidate_get_event_cache(event_id)
|
self._invalidate_get_event_cache(event_id)
|
||||||
|
|
||||||
stream = result.get("backward_ex_outliers")
|
stream = result.get("backward_ex_outliers")
|
||||||
if stream:
|
if stream:
|
||||||
|
self._backfill_id_gen.advance(-stream["position"])
|
||||||
for row in stream["rows"]:
|
for row in stream["rows"]:
|
||||||
event_id = row[1]
|
event_id = row[1]
|
||||||
self._invalidate_get_event_cache(event_id)
|
self._invalidate_get_event_cache(event_id)
|
||||||
|
|
|
@ -355,5 +355,76 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
|
"""Handles user creation via a server-to-server interface
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = client_path_patterns("/createUser$", releases=())
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(CreateUserRestServlet, self).__init__(hs)
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
user_json = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
if "access_token" not in request.args:
|
||||||
|
raise SynapseError(400, "Expected application service token.")
|
||||||
|
|
||||||
|
app_service = yield self.store.get_app_service_by_token(
|
||||||
|
request.args["access_token"][0]
|
||||||
|
)
|
||||||
|
if not app_service:
|
||||||
|
raise SynapseError(403, "Invalid application service token.")
|
||||||
|
|
||||||
|
logger.debug("creating user: %s", user_json)
|
||||||
|
|
||||||
|
response = yield self._do_create(user_json)
|
||||||
|
|
||||||
|
defer.returnValue((200, response))
|
||||||
|
|
||||||
|
def on_OPTIONS(self, request):
|
||||||
|
return 403, {}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _do_create(self, user_json):
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
if "localpart" not in user_json:
|
||||||
|
raise SynapseError(400, "Expected 'localpart' key.")
|
||||||
|
|
||||||
|
if "displayname" not in user_json:
|
||||||
|
raise SynapseError(400, "Expected 'displayname' key.")
|
||||||
|
|
||||||
|
if "duration_seconds" not in user_json:
|
||||||
|
raise SynapseError(400, "Expected 'duration_seconds' key.")
|
||||||
|
|
||||||
|
localpart = user_json["localpart"].encode("utf-8")
|
||||||
|
displayname = user_json["displayname"].encode("utf-8")
|
||||||
|
duration_seconds = 0
|
||||||
|
try:
|
||||||
|
duration_seconds = int(user_json["duration_seconds"])
|
||||||
|
except ValueError:
|
||||||
|
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
||||||
|
if duration_seconds > self.direct_user_creation_max_duration:
|
||||||
|
duration_seconds = self.direct_user_creation_max_duration
|
||||||
|
|
||||||
|
handler = self.handlers.registration_handler
|
||||||
|
user_id, token = yield handler.get_or_create_user(
|
||||||
|
localpart=localpart,
|
||||||
|
displayname=displayname,
|
||||||
|
duration_seconds=duration_seconds
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"user_id": user_id,
|
||||||
|
"access_token": token,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
RegisterRestServlet(hs).register(http_server)
|
RegisterRestServlet(hs).register(http_server)
|
||||||
|
CreateUserRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -453,7 +453,9 @@ class SQLBaseStore(object):
|
||||||
keyvalues (dict): The unique key tables and their new values
|
keyvalues (dict): The unique key tables and their new values
|
||||||
values (dict): The nonunique columns and their new values
|
values (dict): The nonunique columns and their new values
|
||||||
insertion_values (dict): key/values to use when inserting
|
insertion_values (dict): key/values to use when inserting
|
||||||
Returns: A deferred
|
Returns:
|
||||||
|
Deferred(bool): True if a new entry was created, False if an
|
||||||
|
existing one was updated.
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
desc,
|
desc,
|
||||||
|
@ -498,6 +500,10 @@ class SQLBaseStore(object):
|
||||||
)
|
)
|
||||||
txn.execute(sql, allvalues.values())
|
txn.execute(sql, allvalues.values())
|
||||||
|
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
def _simple_select_one(self, table, keyvalues, retcols,
|
def _simple_select_one(self, table, keyvalues, retcols,
|
||||||
allow_none=False, desc="_simple_select_one"):
|
allow_none=False, desc="_simple_select_one"):
|
||||||
"""Executes a SELECT query on the named table, which is expected to
|
"""Executes a SELECT query on the named table, which is expected to
|
||||||
|
|
|
@ -224,6 +224,18 @@ class EventPushActionsStore(SQLBaseStore):
|
||||||
(room_id, event_id)
|
(room_id, event_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _remove_push_actions_before_txn(self, txn, room_id, user_id,
|
||||||
|
topological_ordering):
|
||||||
|
txn.call_after(
|
||||||
|
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
||||||
|
(room_id, user_id, )
|
||||||
|
)
|
||||||
|
txn.execute(
|
||||||
|
"DELETE FROM event_push_actions"
|
||||||
|
" WHERE room_id = ? AND user_id = ? AND topological_ordering < ?",
|
||||||
|
(room_id, user_id, topological_ordering,)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _action_has_highlight(actions):
|
def _action_has_highlight(actions):
|
||||||
for action in actions:
|
for action in actions:
|
||||||
|
|
|
@ -156,8 +156,7 @@ class PusherStore(SQLBaseStore):
|
||||||
profile_tag=""):
|
profile_tag=""):
|
||||||
with self._pushers_id_gen.get_next() as stream_id:
|
with self._pushers_id_gen.get_next() as stream_id:
|
||||||
def f(txn):
|
def f(txn):
|
||||||
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
|
newly_inserted = self._simple_upsert_txn(
|
||||||
return self._simple_upsert_txn(
|
|
||||||
txn,
|
txn,
|
||||||
"pushers",
|
"pushers",
|
||||||
{
|
{
|
||||||
|
@ -178,11 +177,18 @@ class PusherStore(SQLBaseStore):
|
||||||
"id": stream_id,
|
"id": stream_id,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
defer.returnValue((yield self.runInteraction("add_pusher", f)))
|
if newly_inserted:
|
||||||
|
# get_users_with_pushers_in_room only cares if the user has
|
||||||
|
# at least *one* pusher.
|
||||||
|
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
|
||||||
|
|
||||||
|
yield self.runInteraction("add_pusher", f)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
|
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
|
||||||
def delete_pusher_txn(txn, stream_id):
|
def delete_pusher_txn(txn, stream_id):
|
||||||
|
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
|
||||||
|
|
||||||
self._simple_delete_one_txn(
|
self._simple_delete_one_txn(
|
||||||
txn,
|
txn,
|
||||||
"pushers",
|
"pushers",
|
||||||
|
@ -194,6 +200,7 @@ class PusherStore(SQLBaseStore):
|
||||||
{"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
|
{"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
|
||||||
{"stream_id": stream_id},
|
{"stream_id": stream_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._pushers_id_gen.get_next() as stream_id:
|
with self._pushers_id_gen.get_next() as stream_id:
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"delete_pusher", delete_pusher_txn, stream_id
|
"delete_pusher", delete_pusher_txn, stream_id
|
||||||
|
|
|
@ -100,7 +100,7 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue([ev for res in results.values() for ev in res])
|
defer.returnValue([ev for res in results.values() for ev in res])
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3, max_entries=5000)
|
@cachedInlineCallbacks(num_args=3, max_entries=5000, lru=True, tree=True)
|
||||||
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||||
"""Get receipts for a single room for sending to clients.
|
"""Get receipts for a single room for sending to clients.
|
||||||
|
|
||||||
|
@ -232,7 +232,7 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
|
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
|
||||||
)
|
)
|
||||||
# FIXME: This shouldn't invalidate the whole cache
|
# FIXME: This shouldn't invalidate the whole cache
|
||||||
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
|
txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
|
||||||
|
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self._receipts_stream_cache.entity_has_changed,
|
self._receipts_stream_cache.entity_has_changed,
|
||||||
|
@ -244,6 +244,17 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
(user_id, room_id, receipt_type)
|
(user_id, room_id, receipt_type)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
res = self._simple_select_one_txn(
|
||||||
|
txn,
|
||||||
|
table="events",
|
||||||
|
retcols=["topological_ordering", "stream_ordering"],
|
||||||
|
keyvalues={"event_id": event_id},
|
||||||
|
allow_none=True
|
||||||
|
)
|
||||||
|
|
||||||
|
topological_ordering = int(res["topological_ordering"]) if res else None
|
||||||
|
stream_ordering = int(res["stream_ordering"]) if res else None
|
||||||
|
|
||||||
# We don't want to clobber receipts for more recent events, so we
|
# We don't want to clobber receipts for more recent events, so we
|
||||||
# have to compare orderings of existing receipts
|
# have to compare orderings of existing receipts
|
||||||
sql = (
|
sql = (
|
||||||
|
@ -255,16 +266,7 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
txn.execute(sql, (room_id, receipt_type, user_id))
|
txn.execute(sql, (room_id, receipt_type, user_id))
|
||||||
results = txn.fetchall()
|
results = txn.fetchall()
|
||||||
|
|
||||||
if results:
|
if results and topological_ordering:
|
||||||
res = self._simple_select_one_txn(
|
|
||||||
txn,
|
|
||||||
table="events",
|
|
||||||
retcols=["topological_ordering", "stream_ordering"],
|
|
||||||
keyvalues={"event_id": event_id},
|
|
||||||
)
|
|
||||||
topological_ordering = int(res["topological_ordering"])
|
|
||||||
stream_ordering = int(res["stream_ordering"])
|
|
||||||
|
|
||||||
for to, so, _ in results:
|
for to, so, _ in results:
|
||||||
if int(to) > topological_ordering:
|
if int(to) > topological_ordering:
|
||||||
return False
|
return False
|
||||||
|
@ -294,6 +296,14 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if receipt_type == "m.read" and topological_ordering:
|
||||||
|
self._remove_push_actions_before_txn(
|
||||||
|
txn,
|
||||||
|
room_id=room_id,
|
||||||
|
user_id=user_id,
|
||||||
|
topological_ordering=topological_ordering,
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -367,7 +377,7 @@ class ReceiptsStore(SQLBaseStore):
|
||||||
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
|
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
|
||||||
)
|
)
|
||||||
# FIXME: This shouldn't invalidate the whole cache
|
# FIXME: This shouldn't invalidate the whole cache
|
||||||
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
|
txn.call_after(self.get_linearized_receipts_for_room.invalidate_many, (room_id,))
|
||||||
|
|
||||||
self._simple_delete_txn(
|
self._simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
|
|
38
synapse/storage/schema/delta/32/remove_indices.sql
Normal file
38
synapse/storage/schema/delta/32/remove_indices.sql
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
/* Copyright 2016 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
-- The following indices are redundant, other indices are equivalent or
|
||||||
|
-- supersets
|
||||||
|
DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream
|
||||||
|
DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room
|
||||||
|
DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room
|
||||||
|
DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY
|
||||||
|
DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
|
||||||
|
DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY
|
||||||
|
DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT
|
||||||
|
|
||||||
|
DROP INDEX IF EXISTS event_destinations_id; -- Prefix of UNIQUE CONSTRAINT
|
||||||
|
DROP INDEX IF EXISTS st_extrem_id; -- Prefix of UNIQUE CONSTRAINT
|
||||||
|
DROP INDEX IF EXISTS event_content_hashes_id; -- Prefix of UNIQUE CONSTRAINT
|
||||||
|
DROP INDEX IF EXISTS event_signatures_id; -- Prefix of UNIQUE CONSTRAINT
|
||||||
|
DROP INDEX IF EXISTS event_edge_hashes_id; -- Prefix of UNIQUE CONSTRAINT
|
||||||
|
DROP INDEX IF EXISTS redactions_event_id; -- Duplicate of UNIQUE CONSTRAINT
|
||||||
|
DROP INDEX IF EXISTS room_hosts_room_id; -- Prefix of UNIQUE CONSTRAINT
|
||||||
|
|
||||||
|
-- The following indices were unused
|
||||||
|
DROP INDEX IF EXISTS remote_media_cache_thumbnails_media_id;
|
||||||
|
DROP INDEX IF EXISTS evauth_edges_auth_id;
|
||||||
|
DROP INDEX IF EXISTS presence_stream_state;
|
|
@ -284,12 +284,12 @@ class AuthTestCase(unittest.TestCase):
|
||||||
macaroon.add_first_party_caveat("time < 1") # ms
|
macaroon.add_first_party_caveat("time < 1") # ms
|
||||||
|
|
||||||
self.hs.clock.now = 5000 # seconds
|
self.hs.clock.now = 5000 # seconds
|
||||||
|
self.hs.config.expire_access_token = True
|
||||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
# yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||||
# TODO(daniel): Turn on the check that we validate expiration, when we
|
# TODO(daniel): Turn on the check that we validate expiration, when we
|
||||||
# validate expiration (and remove the above line, which will start
|
# validate expiration (and remove the above line, which will start
|
||||||
# throwing).
|
# throwing).
|
||||||
# with self.assertRaises(AuthError) as cm:
|
with self.assertRaises(AuthError) as cm:
|
||||||
# yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||||
# self.assertEqual(401, cm.exception.code)
|
self.assertEqual(401, cm.exception.code)
|
||||||
# self.assertIn("Invalid macaroon", cm.exception.msg)
|
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||||
|
|
67
tests/handlers/test_register.py
Normal file
67
tests/handlers/test_register.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from .. import unittest
|
||||||
|
|
||||||
|
from synapse.handlers.register import RegistrationHandler
|
||||||
|
|
||||||
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationHandlers(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.registration_handler = RegistrationHandler(hs)
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationTestCase(unittest.TestCase):
|
||||||
|
""" Tests the RegistrationHandler. """
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def setUp(self):
|
||||||
|
self.mock_distributor = Mock()
|
||||||
|
self.mock_distributor.declare("registered_user")
|
||||||
|
self.mock_captcha_client = Mock()
|
||||||
|
hs = yield setup_test_homeserver(
|
||||||
|
handlers=None,
|
||||||
|
http_client=None,
|
||||||
|
expire_access_token=True)
|
||||||
|
hs.handlers = RegistrationHandlers(hs)
|
||||||
|
self.handler = hs.get_handlers().registration_handler
|
||||||
|
hs.get_handlers().profile_handler = Mock()
|
||||||
|
self.mock_handler = Mock(spec=[
|
||||||
|
"generate_short_term_login_token",
|
||||||
|
])
|
||||||
|
|
||||||
|
hs.get_handlers().auth_handler = self.mock_handler
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
The user doess not exist in this case so it will register and log it in
|
||||||
|
"""
|
||||||
|
duration_ms = 200
|
||||||
|
local_part = "someone"
|
||||||
|
display_name = "someone"
|
||||||
|
user_id = "@someone:test"
|
||||||
|
mock_token = self.mock_handler.generate_short_term_login_token
|
||||||
|
mock_token.return_value = 'secret'
|
||||||
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
|
local_part, display_name, duration_ms)
|
||||||
|
self.assertEquals(result_user_id, user_id)
|
||||||
|
self.assertEquals(result_token, 'secret')
|
56
tests/replication/slave/storage/test_account_data.py
Normal file
56
tests/replication/slave/storage/test_account_data.py
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from ._base import BaseSlavedStoreTestCase
|
||||||
|
|
||||||
|
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
USER_ID = "@feeling:blue"
|
||||||
|
TYPE = "my.type"
|
||||||
|
|
||||||
|
|
||||||
|
class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
|
||||||
|
|
||||||
|
STORE_TYPE = SlavedAccountDataStore
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_user_account_data(self):
|
||||||
|
yield self.master_store.add_account_data_for_user(
|
||||||
|
USER_ID, TYPE, {"a": 1}
|
||||||
|
)
|
||||||
|
yield self.replicate()
|
||||||
|
yield self.check(
|
||||||
|
"get_global_account_data_by_type_for_user",
|
||||||
|
[TYPE, USER_ID], {"a": 1}
|
||||||
|
)
|
||||||
|
yield self.check(
|
||||||
|
"get_global_account_data_by_type_for_users",
|
||||||
|
[TYPE, [USER_ID]], {USER_ID: {"a": 1}}
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.master_store.add_account_data_for_user(
|
||||||
|
USER_ID, TYPE, {"a": 2}
|
||||||
|
)
|
||||||
|
yield self.replicate()
|
||||||
|
yield self.check(
|
||||||
|
"get_global_account_data_by_type_for_user",
|
||||||
|
[TYPE, USER_ID], {"a": 2}
|
||||||
|
)
|
||||||
|
yield self.check(
|
||||||
|
"get_global_account_data_by_type_for_users",
|
||||||
|
[TYPE, [USER_ID]], {USER_ID: {"a": 2}}
|
||||||
|
)
|
88
tests/rest/client/v1/test_register.py
Normal file
88
tests/rest/client/v1/test_register.py
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.rest.client.v1.register import CreateUserRestServlet
|
||||||
|
from twisted.internet import defer
|
||||||
|
from mock import Mock
|
||||||
|
from tests import unittest
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class CreateUserServletTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# do the dance to hook up request data to self.request_data
|
||||||
|
self.request_data = ""
|
||||||
|
self.request = Mock(
|
||||||
|
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
|
||||||
|
path='/_matrix/client/api/v1/createUser'
|
||||||
|
)
|
||||||
|
self.request.args = {}
|
||||||
|
|
||||||
|
self.appservice = None
|
||||||
|
self.auth = Mock(get_appservice_by_req=Mock(
|
||||||
|
side_effect=lambda x: defer.succeed(self.appservice))
|
||||||
|
)
|
||||||
|
|
||||||
|
self.auth_result = (False, None, None, None)
|
||||||
|
self.auth_handler = Mock(
|
||||||
|
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||||
|
get_session_data=Mock(return_value=None)
|
||||||
|
)
|
||||||
|
self.registration_handler = Mock()
|
||||||
|
self.identity_handler = Mock()
|
||||||
|
self.login_handler = Mock()
|
||||||
|
|
||||||
|
# do the dance to hook it up to the hs global
|
||||||
|
self.handlers = Mock(
|
||||||
|
auth_handler=self.auth_handler,
|
||||||
|
registration_handler=self.registration_handler,
|
||||||
|
identity_handler=self.identity_handler,
|
||||||
|
login_handler=self.login_handler
|
||||||
|
)
|
||||||
|
self.hs = Mock()
|
||||||
|
self.hs.hostname = "supergbig~testing~thing.com"
|
||||||
|
self.hs.get_auth = Mock(return_value=self.auth)
|
||||||
|
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||||
|
self.hs.config.enable_registration = True
|
||||||
|
# init the thing we're testing
|
||||||
|
self.servlet = CreateUserRestServlet(self.hs)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_POST_createuser_with_valid_user(self):
|
||||||
|
user_id = "@someone:interesting"
|
||||||
|
token = "my token"
|
||||||
|
self.request.args = {
|
||||||
|
"access_token": "i_am_an_app_service"
|
||||||
|
}
|
||||||
|
self.request_data = json.dumps({
|
||||||
|
"localpart": "someone",
|
||||||
|
"displayname": "someone interesting",
|
||||||
|
"duration_seconds": 200
|
||||||
|
})
|
||||||
|
|
||||||
|
self.registration_handler.get_or_create_user = Mock(
|
||||||
|
return_value=(user_id, token)
|
||||||
|
)
|
||||||
|
|
||||||
|
(code, result) = yield self.servlet.on_POST(self.request)
|
||||||
|
self.assertEquals(code, 200)
|
||||||
|
|
||||||
|
det_data = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"access_token": token,
|
||||||
|
"home_server": self.hs.hostname
|
||||||
|
}
|
||||||
|
self.assertDictContainsSubset(det_data, result)
|
|
@ -49,6 +49,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||||
config.event_cache_size = 1
|
config.event_cache_size = 1
|
||||||
config.enable_registration = True
|
config.enable_registration = True
|
||||||
config.macaroon_secret_key = "not even a little secret"
|
config.macaroon_secret_key = "not even a little secret"
|
||||||
|
config.expire_access_token = False
|
||||||
config.server_name = "server.under.test"
|
config.server_name = "server.under.test"
|
||||||
config.trusted_third_party_id_servers = []
|
config.trusted_third_party_id_servers = []
|
||||||
config.room_invite_state_types = []
|
config.room_invite_state_types = []
|
||||||
|
|
Loading…
Reference in a new issue