0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-07-06 19:28:46 +02:00

Merge branch 'develop' into rav/invalid_request_utf8

This commit is contained in:
Richard van der Hoff 2017-11-13 11:56:22 +00:00
commit 8b33ac8f6c
6 changed files with 46 additions and 18 deletions

View file

@ -15,7 +15,6 @@
"""Contains functions for registering clients.""" """Contains functions for registering clients."""
import logging import logging
import urllib
from twisted.internet import defer from twisted.internet import defer
@ -23,6 +22,7 @@ from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
from synapse import types
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from ._base import BaseHandler from ._base import BaseHandler
@ -47,7 +47,7 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None, def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None): assigned_user_id=None):
if urllib.quote(localpart.encode('utf-8')) != localpart: if types.contains_invalid_mxid_characters(localpart):
raise SynapseError( raise SynapseError(
400, 400,
"User ID can only contain characters a-z, 0-9, or '=_-./'", "User ID can only contain characters a-z, 0-9, or '=_-./'",
@ -253,7 +253,7 @@ class RegistrationHandler(BaseHandler):
""" """
Registers email_id as SAML2 Based Auth. Registers email_id as SAML2 Based Auth.
""" """
if urllib.quote(localpart) != localpart: if types.contains_invalid_mxid_characters(localpart):
raise SynapseError( raise SynapseError(
400, 400,
"User ID can only contain characters a-z, 0-9, or '=_-./'", "User ID can only contain characters a-z, 0-9, or '=_-./'",

View file

@ -359,7 +359,7 @@ class RegisterRestServlet(ClientV1RestServlet):
if compare_digest(want_mac, got_mac): if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
user_id, token = yield handler.register( user_id, token = yield handler.register(
localpart=user, localpart=user.lower(),
password=password, password=password,
admin=bool(admin), admin=bool(admin),
) )

View file

@ -224,6 +224,12 @@ class RegisterRestServlet(RestServlet):
# 'user' key not 'username'). Since this is a new addition, we'll # 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one. # fallback to 'username' if they gave one.
desired_username = body.get("user", desired_username) desired_username = body.get("user", desired_username)
# XXX we should check that desired_username is valid. Currently
# we give appservices carte blanche for any insanity in mxids,
# because the IRC bridges rely on being able to register stupid
# IDs.
access_token = get_access_token_from_request(request) access_token = get_access_token_from_request(request)
if isinstance(desired_username, basestring): if isinstance(desired_username, basestring):
@ -233,6 +239,15 @@ class RegisterRestServlet(RestServlet):
defer.returnValue((200, result)) # we throw for non 200 responses defer.returnValue((200, result)) # we throw for non 200 responses
return return
# for either shared secret or regular registration, downcase the
# provided username before attempting to register it. This should mean
# that people who try to register with upper-case in their usernames
# don't get a nasty surprise. (Note that we treat username
# case-insenstively in login, so they are free to carry on imagining
# that their username is CrAzYh4cKeR if that keeps them happy)
if desired_username is not None:
desired_username = desired_username.lower()
# == Shared Secret Registration == (e.g. create new user scripts) # == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body: if 'mac' in body:
# FIXME: Should we really be determining if this is shared secret # FIXME: Should we really be determining if this is shared secret
@ -336,6 +351,9 @@ class RegisterRestServlet(RestServlet):
new_password = params.get("password", None) new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None) guest_access_token = params.get("guest_access_token", None)
if desired_username is not None:
desired_username = desired_username.lower()
(registered_user_id, _) = yield self.registration_handler.register( (registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username, localpart=desired_username,
password=new_password, password=new_password,
@ -417,13 +435,22 @@ class RegisterRestServlet(RestServlet):
def _do_shared_secret_registration(self, username, password, body): def _do_shared_secret_registration(self, username, password, body):
if not self.hs.config.registration_shared_secret: if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled") raise SynapseError(400, "Shared secret registration is not enabled")
if not username:
raise SynapseError(
400, "username must be specified", errcode=Codes.BAD_JSON,
)
user = username.encode("utf-8") # use the username from the original request rather than the
# downcased one in `username` for the mac calculation
user = body["username"].encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not # str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface # have the buffer interface
got_mac = str(body["mac"]) got_mac = str(body["mac"])
# FIXME this is different to the /v1/register endpoint, which
# includes the password and admin flag in the hashed text. Why are
# these different?
want_mac = hmac.new( want_mac = hmac.new(
key=self.hs.config.registration_shared_secret, key=self.hs.config.registration_shared_secret,
msg=user, msg=user,

View file

@ -16,6 +16,8 @@ import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
import synapse.metrics import synapse.metrics
@ -178,6 +180,10 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
)
self._event_fetch_lock = threading.Condition() self._event_fetch_lock = threading.Condition()
self._event_fetch_list = [] self._event_fetch_list = []
self._event_fetch_ongoing = 0 self._event_fetch_ongoing = 0

View file

@ -63,7 +63,7 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_user", get_account_data_for_user_txn "get_account_data_for_user", get_account_data_for_user_txn
) )
@cachedInlineCallbacks(num_args=2) @cachedInlineCallbacks(num_args=2, max_entries=5000)
def get_global_account_data_by_type_for_user(self, data_type, user_id): def get_global_account_data_by_type_for_user(self, data_type, user_id):
""" """
Returns: Returns:

View file

@ -13,17 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import namedtuple from ._base import SQLBaseStore
import logging from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches import intern_string
from synapse.util.stringutils import to_ascii
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple
from synapse.storage.engines import PostgresEngine import logging
from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.stringutils import to_ascii
from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -82,10 +81,6 @@ class StateStore(SQLBaseStore):
where_clause="type='m.room.member'", where_clause="type='m.room.member'",
) )
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
)
@cached(max_entries=100000, iterable=True) @cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id): def get_current_state_ids(self, room_id):
"""Get the current state event ids for a room based on the """Get the current state event ids for a room based on the