0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-01 22:33:51 +01:00

Merge pull request #7866 from matrix-org/rav/fix_guest_user_id

Fix guest user registration with lots of client readers
This commit is contained in:
Richard van der Hoff 2020-07-16 13:54:45 +01:00 committed by GitHub
commit a827838706
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 197 additions and 82 deletions

1
changelog.d/7866.bugfix Normal file
View file

@ -0,0 +1 @@
Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers.

View file

@ -48,6 +48,7 @@ from synapse.storage.data_stores.main.media_repository import (
)
from synapse.storage.data_stores.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
)
from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
@ -622,8 +623,10 @@ class Porter(object):
)
)
# Step 5. Do final post-processing
# Step 5. Set up sequences
self.progress.set_state("Setting up sequence generators")
await self._setup_state_group_id_seq()
await self._setup_user_id_seq()
self.progress.done()
except Exception as e:
@ -793,6 +796,13 @@ class Porter(object):
return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
def _setup_user_id_seq(self):
def r(txn):
next_id = find_max_generated_user_id_localpart(txn) + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
return self.postgres_store.db.runInteraction("setup_user_id_seq", r)
##############################################
# The following is simply UI stuff

View file

@ -28,7 +28,6 @@ from synapse.replication.http.register import (
)
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
@ -50,14 +49,7 @@ class RegistrationHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
self._next_generated_user_id = None
self.macaroon_gen = hs.get_macaroon_generator()
self._generate_user_id_linearizer = Linearizer(
name="_generate_user_id_linearizer"
)
self._server_notices_mxid = hs.config.server_notices_mxid
if hs.config.worker_app:
@ -219,7 +211,7 @@ class RegistrationHandler(BaseHandler):
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
localpart = await self._generate_user_id()
localpart = await self.store.generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
@ -510,18 +502,6 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
async def _generate_user_id(self):
if self._next_generated_user_id is None:
with await self._generate_user_id_linearizer.queue(()):
if self._next_generated_user_id is None:
self._next_generated_user_id = (
await self.store.find_next_generated_user_id_localpart()
)
id = self._next_generated_user_id
self._next_generated_user_id += 1
return str(id)
def check_registration_ratelimit(self, address):
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address

View file

@ -27,6 +27,8 @@ from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidati
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@ -42,6 +44,10 @@ class RegistrationWorkerStore(SQLBaseStore):
self.config = hs.config
self.clock = hs.get_clock()
self._user_id_seq = build_sequence_generator(
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
@cached()
def get_user_by_id(self, user_id):
return self.db.simple_select_one(
@ -481,39 +487,17 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = yield self.db.runInteraction("count_real_users", _count_users)
return ret
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
async def generate_user_id(self) -> str:
"""Generate a suitable localpart for a guest user
Returns: a (hopefully) free localpart
"""
Gets the localpart of the next generated user ID.
Generated user IDs are integers, so we find the largest integer user ID
already taken and return that plus one.
"""
def _find_next_generated_user_id(txn):
# We bound between '@0' and '@a' to avoid pulling the entire table
# out.
txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
regex = re.compile(r"^@(\d+):")
max_found = 0
for (user_id,) in txn:
match = regex.search(user_id)
if match:
max_found = max(int(match.group(1)), max_found)
return max_found + 1
return (
(
yield self.db.runInteraction(
"find_next_generated_user_id", _find_next_generated_user_id
)
)
next_id = await self.db.runInteraction(
"generate_user_id", self._user_id_seq.get_next_id_txn
)
return str(next_id)
async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
"""Returns user id from threepid
@ -1573,3 +1557,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
keyvalues={"user_id": user_id},
values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
Gets the localpart of the max current generated user ID.
Generated user IDs are integers, so we find the largest integer user ID
already taken and return that.
"""
# We bound between '@0' and '@a' to avoid pulling the entire table
# out.
cur.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
regex = re.compile(r"^@(\d+):")
max_found = 0
for (user_id,) in cur:
match = regex.search(user_id)
if match:
max_found = max(int(match.group(1)), max_found)
return max_found

View file

@ -0,0 +1,34 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adds a postgres SEQUENCE for generating guest user IDs.
"""
from synapse.storage.data_stores.main.registration import (
find_max_generated_user_id_localpart,
)
from synapse.storage.engines import PostgresEngine
def run_create(cur, database_engine, *args, **kwargs):
if not isinstance(database_engine, PostgresEngine):
return
next_id = find_max_generated_user_id_localpart(cur) + 1
cur.execute("CREATE SEQUENCE user_id_seq START WITH %s", (next_id,))
def run_upgrade(*args, **kwargs):
pass

View file

@ -24,6 +24,8 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database
from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@ -92,6 +94,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"*stateGroupMembersCache*", 500000,
)
def get_max_state_group_txn(txn: Cursor):
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
return txn.fetchone()[0]
self._state_group_seq_gen = build_sequence_generator(
self.database_engine, get_max_state_group_txn, "state_group_id_seq"
)
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
@ -386,7 +396,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
state_group = self.database_engine.get_next_state_group_id(txn)
state_group = self._state_group_seq_gen.get_next_id_txn(txn)
self.db.simple_insert_txn(
txn,

View file

@ -91,12 +91,6 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
def lock_table(self, txn, table: str) -> None:
...
@abc.abstractmethod
def get_next_state_group_id(self, txn) -> int:
"""Returns an int that can be used as a new state_group ID
"""
...
@property
@abc.abstractmethod
def server_version(self) -> str:

View file

@ -154,12 +154,6 @@ class PostgresEngine(BaseDatabaseEngine):
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
def get_next_state_group_id(self, txn):
"""Returns an int that can be used as a new state_group ID
"""
txn.execute("SELECT nextval('state_group_id_seq')")
return txn.fetchone()[0]
@property
def server_version(self):
"""Returns a string giving the server version. For example: '8.1.5'

View file

@ -96,19 +96,6 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def lock_table(self, txn, table):
return
def get_next_state_group_id(self, txn):
"""Returns an int that can be used as a new state_group ID
"""
# We do application locking here since if we're using sqlite then
# we are a single process synapse.
with self._current_state_group_id_lock:
if self._current_state_group_id is None:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
self._current_state_group_id = txn.fetchone()[0]
self._current_state_group_id += 1
return self._current_state_group_id
@property
def server_version(self):
"""Gets a string giving the server version. For example: '3.22.0'

View file

@ -21,6 +21,7 @@ from typing import Dict, Set, Tuple
from typing_extensions import Deque
from synapse.storage.database import Database, LoggingTransaction
from synapse.storage.util.sequence import PostgresSequenceGenerator
class IdGenerator(object):
@ -247,7 +248,6 @@ class MultiWriterIdGenerator:
):
self._db = db
self._instance_name = instance_name
self._sequence_name = sequence_name
# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
@ -260,6 +260,8 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
@ -283,9 +285,7 @@ class MultiWriterIdGenerator:
return current_positions
def _load_next_id_txn(self, txn):
txn.execute("SELECT nextval(?)", (self._sequence_name,))
(next_id,) = txn.fetchone()
return next_id
return self._sequence_gen.get_next_id_txn(txn)
async def get_next(self):
"""

View file

@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import threading
from typing import Callable, Optional
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
class SequenceGenerator(metaclass=abc.ABCMeta):
"""A class which generates a unique sequence of integers"""
@abc.abstractmethod
def get_next_id_txn(self, txn: Cursor) -> int:
"""Gets the next ID in the sequence"""
...
class PostgresSequenceGenerator(SequenceGenerator):
"""An implementation of SequenceGenerator which uses a postgres sequence"""
def __init__(self, sequence_name: str):
self._sequence_name = sequence_name
def get_next_id_txn(self, txn: Cursor) -> int:
txn.execute("SELECT nextval(?)", (self._sequence_name,))
return txn.fetchone()[0]
GetFirstCallbackType = Callable[[Cursor], int]
class LocalSequenceGenerator(SequenceGenerator):
"""An implementation of SequenceGenerator which uses local locking
This only works reliably if there are no other worker processes generating IDs at
the same time.
"""
def __init__(self, get_first_callback: GetFirstCallbackType):
"""
Args:
get_first_callback: a callback which is called on the first call to
get_next_id_txn; should return the curreent maximum id
"""
# the callback. this is cleared after it is called, so that it can be GCed.
self._callback = get_first_callback # type: Optional[GetFirstCallbackType]
# The current max value, or None if we haven't looked in the DB yet.
self._current_max_id = None # type: Optional[int]
self._lock = threading.Lock()
def get_next_id_txn(self, txn: Cursor) -> int:
# We do application locking here since if we're using sqlite then
# we are a single process synapse.
with self._lock:
if self._current_max_id is None:
assert self._callback is not None
self._current_max_id = self._callback(txn)
self._callback = None
self._current_max_id += 1
return self._current_max_id
def build_sequence_generator(
database_engine: BaseDatabaseEngine,
get_first_callback: GetFirstCallbackType,
sequence_name: str,
) -> SequenceGenerator:
"""Get the best impl of SequenceGenerator available
This uses PostgresSequenceGenerator on postgres, and a locally-locked impl on
sqlite.
Args:
database_engine: the database engine we are connected to
get_first_callback: a callback which gets the next sequence ID. Used if
we're on sqlite.
sequence_name: the name of a postgres sequence to use.
"""
if isinstance(database_engine, PostgresEngine):
return PostgresSequenceGenerator(sequence_name)
else:
return LocalSequenceGenerator(get_first_callback)