0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 17:23:53 +01:00

Merge branch 'erikj/shared_secret' into erikj/test2

This commit is contained in:
Erik Johnston 2016-07-06 14:46:31 +01:00
commit a17e7caeb7
10 changed files with 213 additions and 111 deletions

View file

@ -25,18 +25,26 @@ import urllib2
import yaml import yaml
def request_registration(user, password, server_location, shared_secret): def request_registration(user, password, server_location, shared_secret, admin=False):
mac = hmac.new( mac = hmac.new(
key=shared_secret, key=shared_secret,
msg=user,
digestmod=hashlib.sha1, digestmod=hashlib.sha1,
).hexdigest() )
mac.update(user)
mac.update("\x00")
mac.update(password)
mac.update("\x00")
mac.update("admin" if admin else "notadmin")
mac = mac.hexdigest()
data = { data = {
"user": user, "user": user,
"password": password, "password": password,
"mac": mac, "mac": mac,
"type": "org.matrix.login.shared_secret", "type": "org.matrix.login.shared_secret",
"admin": admin,
} }
server_location = server_location.rstrip("/") server_location = server_location.rstrip("/")
@ -68,7 +76,7 @@ def request_registration(user, password, server_location, shared_secret):
sys.exit(1) sys.exit(1)
def register_new_user(user, password, server_location, shared_secret): def register_new_user(user, password, server_location, shared_secret, admin):
if not user: if not user:
try: try:
default_user = getpass.getuser() default_user = getpass.getuser()
@ -99,7 +107,14 @@ def register_new_user(user, password, server_location, shared_secret):
print "Passwords do not match" print "Passwords do not match"
sys.exit(1) sys.exit(1)
request_registration(user, password, server_location, shared_secret) if not admin:
admin = raw_input("Make admin [no]: ")
if admin in ("y", "yes", "true"):
admin = True
else:
admin = False
request_registration(user, password, server_location, shared_secret, bool(admin))
if __name__ == "__main__": if __name__ == "__main__":
@ -119,6 +134,11 @@ if __name__ == "__main__":
default=None, default=None,
help="New password for user. Will prompt if omitted.", help="New password for user. Will prompt if omitted.",
) )
parser.add_argument(
"-a", "--admin",
action="store_true",
help="Register new user as an admin. Will prompt if omitted.",
)
group = parser.add_mutually_exclusive_group(required=True) group = parser.add_mutually_exclusive_group(required=True)
group.add_argument( group.add_argument(
@ -151,4 +171,4 @@ if __name__ == "__main__":
else: else:
secret = args.shared_secret secret = args.shared_secret
register_new_user(args.user, args.password, args.server_url, secret) register_new_user(args.user, args.password, args.server_url, secret, args.admin)

View file

@ -42,8 +42,9 @@ class Codes(object):
TOO_LARGE = "M_TOO_LARGE" TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE" EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "THREEPID_IN_USE" THREEPID_IN_USE = "M_THREEPID_IN_USE"
INVALID_USERNAME = "M_INVALID_USERNAME" INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):

View file

@ -23,10 +23,14 @@ class PasswordConfig(Config):
def read_config(self, config): def read_config(self, config):
password_config = config.get("password_config", {}) password_config = config.get("password_config", {})
self.password_enabled = password_config.get("enabled", True) self.password_enabled = password_config.get("enabled", True)
self.password_pepper = password_config.get("pepper", "")
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
return """ return """
# Enable password for login. # Enable password for login.
password_config: password_config:
enabled: true enabled: true
# Change to a secret random string.
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
#pepper: ""
""" """

View file

@ -750,7 +750,8 @@ class AuthHandler(BaseHandler):
Returns: Returns:
Hashed password (str). Hashed password (str).
""" """
return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds)) return bcrypt.hashpw(password + self.hs.config.password_pepper,
bcrypt.gensalt(self.bcrypt_rounds))
def validate_hash(self, password, stored_hash): def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.
@ -763,6 +764,7 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash (bool). Whether self.hash(password) == stored_hash (bool).
""" """
if stored_hash: if stored_hash:
return bcrypt.hashpw(password, stored_hash.encode('utf-8')) == stored_hash return bcrypt.hashpw(password + self.hs.config.password_pepper,
stored_hash.encode('utf-8')) == stored_hash
else: else:
return False return False

View file

@ -21,7 +21,7 @@ from synapse.api.errors import (
) )
from ._base import BaseHandler from ._base import BaseHandler
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, Codes
import json import json
import logging import logging
@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler):
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
) )
def _should_trust_id_server(self, id_server):
if id_server not in self.trusted_id_servers:
if self.trust_any_id_server_just_for_testing_do_not_use:
logger.warn(
"Trusting untrustworthy ID server %r even though it isn't"
" in the trusted id list for testing because"
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
" is set in the config",
id_server,
)
else:
return False
return True
@defer.inlineCallbacks @defer.inlineCallbacks
def threepid_from_creds(self, creds): def threepid_from_creds(self, creds):
yield run_on_reactor() yield run_on_reactor()
@ -59,18 +73,11 @@ class IdentityHandler(BaseHandler):
else: else:
raise SynapseError(400, "No client_secret in creds") raise SynapseError(400, "No client_secret in creds")
if id_server not in self.trusted_id_servers: if not self._should_trust_id_server(id_server):
if self.trust_any_id_server_just_for_testing_do_not_use:
logger.warn( logger.warn(
"Trusting untrustworthy ID server %r even though it isn't" '%s is not a trusted ID server: rejecting 3pid ' +
" in the trusted id list for testing because" 'credentials', id_server
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
" is set in the config",
id_server,
) )
else:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', id_server)
defer.returnValue(None) defer.returnValue(None)
data = {} data = {}
@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler):
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
yield run_on_reactor() yield run_on_reactor()
if not self._should_trust_id_server(id_server):
raise SynapseError(
400, "Untrusted ID server '%s'" % id_server,
Codes.SERVER_NOT_TRUSTED
)
params = { params = {
'email': email, 'email': email,
'client_secret': client_secret, 'client_secret': client_secret,

View file

@ -90,7 +90,8 @@ class RegistrationHandler(BaseHandler):
password=None, password=None,
generate_token=True, generate_token=True,
guest_access_token=None, guest_access_token=None,
make_guest=False make_guest=False,
admin=False,
): ):
"""Registers a new client on the server. """Registers a new client on the server.
@ -141,6 +142,7 @@ class RegistrationHandler(BaseHandler):
# If the user was a guest then they already have a profile # If the user was a guest then they already have a profile
None if was_guest else user.localpart None if was_guest else user.localpart
), ),
admin=admin,
) )
else: else:
# autogen a sequential user ID # autogen a sequential user ID

View file

@ -324,6 +324,14 @@ class RegisterRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Shared secret registration is not enabled") raise SynapseError(400, "Shared secret registration is not enabled")
user = register_json["user"].encode("utf-8") user = register_json["user"].encode("utf-8")
password = register_json["password"].encode("utf-8")
admin = register_json.get("admin", None)
# Its important to check as we use null bytes as HMAC field separators
if "\x00" in user:
raise SynapseError(400, "Invalid user")
if "\x00" in password:
raise SynapseError(400, "Invalid password")
# str() because otherwise hmac complains that 'unicode' does not # str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface # have the buffer interface
@ -331,17 +339,21 @@ class RegisterRestServlet(ClientV1RestServlet):
want_mac = hmac.new( want_mac = hmac.new(
key=self.hs.config.registration_shared_secret, key=self.hs.config.registration_shared_secret,
msg=user,
digestmod=sha1, digestmod=sha1,
).hexdigest() )
want_mac.update(user)
password = register_json["password"].encode("utf-8") want_mac.update("\x00")
want_mac.update(password)
want_mac.update("\x00")
want_mac.update("admin" if admin else "notadmin")
want_mac = want_mac.hexdigest()
if compare_digest(want_mac, got_mac): if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
user_id, token = yield handler.register( user_id, token = yield handler.register(
localpart=user, localpart=user,
password=password, password=password,
admin=bool(admin),
) )
self._remove_session(session) self._remove_session(session)
defer.returnValue({ defer.returnValue({

View file

@ -16,6 +16,8 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from twisted.internet import defer from twisted.internet import defer
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.types import RoomStreamToken
from .stream import lower_bound
import logging import logging
import ujson as json import ujson as json
@ -73,6 +75,9 @@ class EventPushActionsStore(SQLBaseStore):
stream_ordering = results[0][0] stream_ordering = results[0][0]
topological_ordering = results[0][1] topological_ordering = results[0][1]
token = RoomStreamToken(
topological_ordering, stream_ordering
)
sql = ( sql = (
"SELECT sum(notif), sum(highlight)" "SELECT sum(notif), sum(highlight)"
@ -80,15 +85,10 @@ class EventPushActionsStore(SQLBaseStore):
" WHERE" " WHERE"
" user_id = ?" " user_id = ?"
" AND room_id = ?" " AND room_id = ?"
" AND (" " AND %s"
" topological_ordering > ?" ) % (lower_bound(token, self.database_engine, inclusive=False),)
" OR (topological_ordering = ? AND stream_ordering > ?)"
")" txn.execute(sql, (user_id, room_id))
)
txn.execute(sql, (
user_id, room_id,
topological_ordering, topological_ordering, stream_ordering
))
row = txn.fetchone() row = txn.fetchone()
if row: if row:
return { return {

View file

@ -77,7 +77,7 @@ class RegistrationStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, user_id, token, password_hash, def register(self, user_id, token, password_hash,
was_guest=False, make_guest=False, appservice_id=None, was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None): create_profile_with_localpart=None, admin=False):
"""Attempts to register an account. """Attempts to register an account.
Args: Args:
@ -104,6 +104,7 @@ class RegistrationStore(SQLBaseStore):
make_guest, make_guest,
appservice_id, appservice_id,
create_profile_with_localpart, create_profile_with_localpart,
admin
) )
self.get_user_by_id.invalidate((user_id,)) self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,)) self.is_guest.invalidate((user_id,))
@ -118,6 +119,7 @@ class RegistrationStore(SQLBaseStore):
make_guest, make_guest,
appservice_id, appservice_id,
create_profile_with_localpart, create_profile_with_localpart,
admin,
): ):
now = int(self.clock.time()) now = int(self.clock.time())
@ -125,29 +127,33 @@ class RegistrationStore(SQLBaseStore):
try: try:
if was_guest: if was_guest:
txn.execute("UPDATE users SET" self._simple_update_one_txn(
" password_hash = ?," txn,
" upgrade_ts = ?," "users",
" is_guest = ?" keyvalues={
" WHERE name = ?", "name": user_id,
[password_hash, now, 1 if make_guest else 0, user_id]) },
updatevalues={
"password_hash": password_hash,
"upgrade_ts": now,
"is_guest": 1 if make_guest else 0,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
}
)
else: else:
txn.execute("INSERT INTO users " self._simple_insert_txn(
"(" txn,
" name," "users",
" password_hash," values={
" creation_ts," "name": user_id,
" is_guest," "password_hash": password_hash,
" appservice_id" "creation_ts": now,
") " "is_guest": 1 if make_guest else 0,
"VALUES (?,?,?,?,?)", "appservice_id": appservice_id,
[ "admin": 1 if admin else 0,
user_id, }
password_hash, )
now,
1 if make_guest else 0,
appservice_id,
])
except self.database_engine.module.IntegrityError: except self.database_engine.module.IntegrityError:
raise StoreError( raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE 400, "User ID already taken.", errcode=Codes.USER_IN_USE

View file

@ -40,6 +40,7 @@ from synapse.util.caches.descriptors import cached
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging import logging
@ -54,25 +55,43 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological" _TOPOLOGICAL_TOKEN = "topological"
def lower_bound(token): def lower_bound(token, engine, inclusive=False):
inclusive = "=" if inclusive else ""
if token.topological is None: if token.topological is None:
return "(%d < %s)" % (token.stream, "stream_ordering") return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
else: else:
return "(%d < %s OR (%d = %s AND %d < %s))" % ( if isinstance(engine, PostgresEngine):
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) <%s (%s,%s))" % (
token.topological, token.stream, inclusive,
"topological_ordering", "stream_ordering",
)
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
token.topological, "topological_ordering", token.topological, "topological_ordering",
token.topological, "topological_ordering", token.topological, "topological_ordering",
token.stream, "stream_ordering", token.stream, inclusive, "stream_ordering",
) )
def upper_bound(token): def upper_bound(token, engine, inclusive=True):
inclusive = "=" if inclusive else ""
if token.topological is None: if token.topological is None:
return "(%d >= %s)" % (token.stream, "stream_ordering") return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
else: else:
return "(%d > %s OR (%d = %s AND %d >= %s))" % ( if isinstance(engine, PostgresEngine):
# Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) >%s (%s,%s))" % (
token.topological, token.stream, inclusive,
"topological_ordering", "stream_ordering",
)
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
token.topological, "topological_ordering", token.topological, "topological_ordering",
token.topological, "topological_ordering", token.topological, "topological_ordering",
token.stream, "stream_ordering", token.stream, inclusive, "stream_ordering",
) )
@ -308,18 +327,22 @@ class StreamStore(SQLBaseStore):
args = [False, room_id] args = [False, room_id]
if direction == 'b': if direction == 'b':
order = "DESC" order = "DESC"
bounds = upper_bound(RoomStreamToken.parse(from_key)) bounds = upper_bound(
if to_key: RoomStreamToken.parse(from_key), self.database_engine
bounds = "%s AND %s" % (
bounds, lower_bound(RoomStreamToken.parse(to_key))
) )
if to_key:
bounds = "%s AND %s" % (bounds, lower_bound(
RoomStreamToken.parse(to_key), self.database_engine
))
else: else:
order = "ASC" order = "ASC"
bounds = lower_bound(RoomStreamToken.parse(from_key)) bounds = lower_bound(
if to_key: RoomStreamToken.parse(from_key), self.database_engine
bounds = "%s AND %s" % (
bounds, upper_bound(RoomStreamToken.parse(to_key))
) )
if to_key:
bounds = "%s AND %s" % (bounds, upper_bound(
RoomStreamToken.parse(to_key), self.database_engine
))
if int(limit) > 0: if int(limit) > 0:
args.append(int(limit)) args.append(int(limit))
@ -586,32 +609,60 @@ class StreamStore(SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"], retcols=["stream_ordering", "topological_ordering"],
) )
stream_ordering = results["stream_ordering"] token = RoomStreamToken(
topological_ordering = results["topological_ordering"] results["topological_ordering"],
results["stream_ordering"],
)
if isinstance(self.database_engine, Sqlite3Engine):
# SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``
# So we give pass it to SQLite3 as the UNION ALL of the two queries.
query_before = ( query_before = (
"SELECT topological_ordering, stream_ordering, event_id FROM events" "SELECT topological_ordering, stream_ordering, event_id FROM events"
" WHERE room_id = ? AND (topological_ordering < ?" " WHERE room_id = ? AND topological_ordering < ?"
" OR (topological_ordering = ? AND stream_ordering < ?))" " UNION ALL"
" ORDER BY topological_ordering DESC, stream_ordering DESC" " SELECT topological_ordering, stream_ordering, event_id FROM events"
" LIMIT ?" " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
" ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
)
before_args = (
room_id, token.topological,
room_id, token.topological, token.stream,
before_limit,
) )
query_after = ( query_after = (
"SELECT topological_ordering, stream_ordering, event_id FROM events" "SELECT topological_ordering, stream_ordering, event_id FROM events"
" WHERE room_id = ? AND (topological_ordering > ?" " WHERE room_id = ? AND topological_ordering > ?"
" OR (topological_ordering = ? AND stream_ordering > ?))" " UNION ALL"
" ORDER BY topological_ordering ASC, stream_ordering ASC" " SELECT topological_ordering, stream_ordering, event_id FROM events"
" LIMIT ?" " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
" ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
) )
after_args = (
room_id, token.topological,
room_id, token.topological, token.stream,
after_limit,
)
else:
query_before = (
"SELECT topological_ordering, stream_ordering, event_id FROM events"
" WHERE room_id = ? AND %s"
" ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
) % (upper_bound(token, self.database_engine, inclusive=False),)
txn.execute( before_args = (room_id, before_limit)
query_before,
( query_after = (
room_id, topological_ordering, topological_ordering, "SELECT topological_ordering, stream_ordering, event_id FROM events"
stream_ordering, before_limit, " WHERE room_id = ? AND %s"
) " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
) ) % (lower_bound(token, self.database_engine, inclusive=False),)
after_args = (room_id, after_limit)
txn.execute(query_before, before_args)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
events_before = [r["event_id"] for r in rows] events_before = [r["event_id"] for r in rows]
@ -623,17 +674,11 @@ class StreamStore(SQLBaseStore):
)) ))
else: else:
start_token = str(RoomStreamToken( start_token = str(RoomStreamToken(
topological_ordering, token.topological,
stream_ordering - 1, token.stream - 1,
)) ))
txn.execute( txn.execute(query_after, after_args)
query_after,
(
room_id, topological_ordering, topological_ordering,
stream_ordering, after_limit,
)
)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
events_after = [r["event_id"] for r in rows] events_after = [r["event_id"] for r in rows]
@ -644,10 +689,7 @@ class StreamStore(SQLBaseStore):
rows[-1]["stream_ordering"], rows[-1]["stream_ordering"],
)) ))
else: else:
end_token = str(RoomStreamToken( end_token = str(token)
topological_ordering,
stream_ordering,
))
return { return {
"before": { "before": {