mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 13:13:50 +01:00
Make numeric user_id checker start at @0, and don't ratelimit on checking (#6338)
This commit is contained in:
commit
a6ebef1bfd
4 changed files with 40 additions and 21 deletions
1
changelog.d/6338.bugfix
Normal file
1
changelog.d/6338.bugfix
Normal file
|
@ -0,0 +1 @@
|
|||
Prevent the server taking a long time to start up when guest registration is enabled.
|
|
@ -24,7 +24,6 @@ from synapse.api.errors import (
|
|||
AuthError,
|
||||
Codes,
|
||||
ConsentNotGivenError,
|
||||
LimitExceededError,
|
||||
RegistrationError,
|
||||
SynapseError,
|
||||
)
|
||||
|
@ -168,6 +167,7 @@ class RegistrationHandler(BaseHandler):
|
|||
Raises:
|
||||
RegistrationError if there was a problem registering.
|
||||
"""
|
||||
yield self.check_registration_ratelimit(address)
|
||||
|
||||
yield self.auth.check_auth_blocking(threepid=threepid)
|
||||
password_hash = None
|
||||
|
@ -217,8 +217,13 @@ class RegistrationHandler(BaseHandler):
|
|||
|
||||
else:
|
||||
# autogen a sequential user ID
|
||||
fail_count = 0
|
||||
user = None
|
||||
while not user:
|
||||
# Fail after being unable to find a suitable ID a few times
|
||||
if fail_count > 10:
|
||||
raise SynapseError(500, "Unable to find a suitable guest user ID")
|
||||
|
||||
localpart = yield self._generate_user_id()
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
@ -233,10 +238,14 @@ class RegistrationHandler(BaseHandler):
|
|||
create_profile_with_displayname=default_display_name,
|
||||
address=address,
|
||||
)
|
||||
|
||||
# Successfully registered
|
||||
break
|
||||
except SynapseError:
|
||||
# if user id is taken, just generate another
|
||||
user = None
|
||||
user_id = None
|
||||
fail_count += 1
|
||||
|
||||
if not self.hs.config.user_consent_at_registration:
|
||||
yield self._auto_join_rooms(user_id)
|
||||
|
@ -414,6 +423,29 @@ class RegistrationHandler(BaseHandler):
|
|||
ratelimit=False,
|
||||
)
|
||||
|
||||
def check_registration_ratelimit(self, address):
|
||||
"""A simple helper method to check whether the registration rate limit has been hit
|
||||
for a given IP address
|
||||
|
||||
Args:
|
||||
address (str|None): the IP address used to perform the registration. If this is
|
||||
None, no ratelimiting will be performed.
|
||||
|
||||
Raises:
|
||||
LimitExceededError: If the rate limit has been exceeded.
|
||||
"""
|
||||
if not address:
|
||||
return
|
||||
|
||||
time_now = self.clock.time()
|
||||
|
||||
self.ratelimiter.ratelimit(
|
||||
address,
|
||||
time_now_s=time_now,
|
||||
rate_hz=self.hs.config.rc_registration.per_second,
|
||||
burst_count=self.hs.config.rc_registration.burst_count,
|
||||
)
|
||||
|
||||
def register_with_store(
|
||||
self,
|
||||
user_id,
|
||||
|
@ -446,22 +478,6 @@ class RegistrationHandler(BaseHandler):
|
|||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
# Don't rate limit for app services
|
||||
if appservice_id is None and address is not None:
|
||||
time_now = self.clock.time()
|
||||
|
||||
allowed, time_allowed = self.ratelimiter.can_do_action(
|
||||
address,
|
||||
time_now_s=time_now,
|
||||
rate_hz=self.hs.config.rc_registration.per_second,
|
||||
burst_count=self.hs.config.rc_registration.burst_count,
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now))
|
||||
)
|
||||
|
||||
if self.hs.config.worker_app:
|
||||
return self._register_client(
|
||||
user_id=user_id,
|
||||
|
|
|
@ -75,6 +75,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||
async def _handle_request(self, request, user_id):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
self.registration_handler.check_registration_ratelimit(content["address"])
|
||||
|
||||
await self.registration_handler.register_with_store(
|
||||
user_id=user_id,
|
||||
password_hash=content["password_hash"],
|
||||
|
|
|
@ -488,14 +488,14 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
we can. Unfortunately, it's possible some of them are already taken by
|
||||
existing users, and there may be gaps in the already taken range. This
|
||||
function returns the start of the first allocatable gap. This is to
|
||||
avoid the case of ID 10000000 being pre-allocated, so us wasting the
|
||||
first (and shortest) many generated user IDs.
|
||||
avoid the case of ID 1000 being pre-allocated and starting at 1001 while
|
||||
0-999 are available.
|
||||
"""
|
||||
|
||||
def _find_next_generated_user_id(txn):
|
||||
# We bound between '@1' and '@a' to avoid pulling the entire table
|
||||
# We bound between '@0' and '@a' to avoid pulling the entire table
|
||||
# out.
|
||||
txn.execute("SELECT name FROM users WHERE '@1' <= name AND name < '@a'")
|
||||
txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
|
||||
|
||||
regex = re.compile(r"^@(\d+):")
|
||||
|
||||
|
|
Loading…
Reference in a new issue