Use a postgres sequence to generate guest user IDs

This commit is contained in:
Richard van der Hoff 2020-07-16 11:46:44 +01:00
parent 3c36ae17a5
commit c445bc0cad
4 changed files with 83 additions and 52 deletions

View file

@ -48,6 +48,7 @@ from synapse.storage.data_stores.main.media_repository import (
)
from synapse.storage.data_stores.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
)
from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
@ -622,8 +623,10 @@ class Porter(object):
)
)
# Step 5. Do final post-processing
# Step 5. Set up sequences
self.progress.set_state("Setting up sequence generators")
await self._setup_state_group_id_seq()
await self._setup_user_id_seq()
self.progress.done()
except Exception as e:
@ -793,6 +796,13 @@ class Porter(object):
return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
def _setup_user_id_seq(self):
def r(txn):
next_id = find_max_generated_user_id_localpart(txn) + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
return self.postgres_store.db.runInteraction("setup_user_id_seq", r)
##############################################
# The following is simply UI stuff

View file

@ -28,7 +28,6 @@ from synapse.replication.http.register import (
)
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
@ -50,14 +49,7 @@ class RegistrationHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
self._next_generated_user_id = None
self.macaroon_gen = hs.get_macaroon_generator()
self._generate_user_id_linearizer = Linearizer(
name="_generate_user_id_linearizer"
)
self._server_notices_mxid = hs.config.server_notices_mxid
if hs.config.worker_app:
@ -219,7 +211,7 @@ class RegistrationHandler(BaseHandler):
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
localpart = await self._generate_user_id()
localpart = await self.store.generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
@ -510,18 +502,6 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
async def _generate_user_id(self):
if self._next_generated_user_id is None:
with await self._generate_user_id_linearizer.queue(()):
if self._next_generated_user_id is None:
self._next_generated_user_id = (
await self.store.find_next_generated_user_id_localpart()
)
id = self._next_generated_user_id
self._next_generated_user_id += 1
return str(id)
def check_registration_ratelimit(self, address):
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address

View file

@ -27,6 +27,8 @@ from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidati
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@ -42,6 +44,10 @@ class RegistrationWorkerStore(SQLBaseStore):
self.config = hs.config
self.clock = hs.get_clock()
self._user_id_seq = build_sequence_generator(
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
@cached()
def get_user_by_id(self, user_id):
return self.db.simple_select_one(
@ -481,39 +487,17 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = yield self.db.runInteraction("count_real_users", _count_users)
return ret
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
async def generate_user_id(self) -> str:
"""Generate a suitable localpart for a guest user

The localpart is the next value obtained from the user ID sequence
generator (``self._user_id_seq``) within a database interaction, converted
to a string.

Returns: a (hopefully) free localpart
"""
Gets the localpart of the next generated user ID.
Generated user IDs are integers, so we find the largest integer user ID
already taken and return that plus one.
"""
def _find_next_generated_user_id(txn):
# We bound between '@0' and '@a' to avoid pulling the entire table
# out.
txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
regex = re.compile(r"^@(\d+):")
max_found = 0
for (user_id,) in txn:
match = regex.search(user_id)
if match:
max_found = max(int(match.group(1)), max_found)
return max_found + 1
return (
(
yield self.db.runInteraction(
"find_next_generated_user_id", _find_next_generated_user_id
)
)
next_id = await self.db.runInteraction(
"generate_user_id", self._user_id_seq.get_next_id_txn
)
return str(next_id)
async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
"""Returns user id from threepid
@ -1573,3 +1557,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
keyvalues={"user_id": user_id},
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
    """Return the largest purely numeric user ID localpart in ``users``.

    Generated user IDs have integer localparts, so this scans the candidate
    rows and reports the biggest integer already taken.

    Args:
        cur: a database cursor; one query is executed on it and its rows
            are then iterated.

    Returns:
        The maximum integer localpart found, or 0 if there is none.
    """
    # The '@0' <= name < '@a' bounds avoid pulling the entire table out.
    cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")

    numeric_localpart = re.compile(r"^@(\d+):")
    candidates = (numeric_localpart.search(name) for (name,) in cur)

    # Matched values are non-negative, so a default of 0 doubles as the floor.
    return max((int(hit.group(1)) for hit in candidates if hit), default=0)

View file

@ -0,0 +1,34 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adds a postgres SEQUENCE for generating guest user IDs.
"""
from synapse.storage.data_stores.main.registration import (
find_max_generated_user_id_localpart,
)
from synapse.storage.engines import PostgresEngine
def run_create(cur, database_engine, *args, **kwargs):
    """Create the ``user_id_seq`` sequence when running on postgres.

    For any engine that is not a ``PostgresEngine`` this does nothing. On
    postgres the sequence starts at one more than the value reported by
    ``find_max_generated_user_id_localpart``.

    Args:
        cur: cursor used both to look up the current maximum and to create
            the sequence.
        database_engine: the engine in use; only its type is inspected.
        *args: ignored.
        **kwargs: ignored.
    """
    if isinstance(database_engine, PostgresEngine):
        first_free_id = 1 + find_max_generated_user_id_localpart(cur)
        cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (first_free_id,))
def run_upgrade(*args, **kwargs):
    """Take no action at upgrade time; all arguments are accepted and ignored."""
    return None