forked from MirrorHub/synapse
Convert ReadWriteLock to async/await. (#8202)
This commit is contained in:
parent
b4826d6eb1
commit
d2ac767de2
4 changed files with 39 additions and 33 deletions
1
changelog.d/8202.misc
Normal file
1
changelog.d/8202.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -14,15 +14,18 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.api.filtering import Filter
|
||||||
from synapse.logging.context import run_in_background
|
from synapse.logging.context import run_in_background
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.streams.config import PaginationConfig
|
||||||
|
from synapse.types import Requester, RoomStreamToken
|
||||||
from synapse.util.async_helpers import ReadWriteLock
|
from synapse.util.async_helpers import ReadWriteLock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
@ -247,15 +250,16 @@ class PaginationHandler(object):
|
||||||
)
|
)
|
||||||
return purge_id
|
return purge_id
|
||||||
|
|
||||||
async def _purge_history(self, purge_id, room_id, token, delete_local_events):
|
async def _purge_history(
|
||||||
|
self, purge_id: str, room_id: str, token: str, delete_local_events: bool
|
||||||
|
) -> None:
|
||||||
"""Carry out a history purge on a room.
|
"""Carry out a history purge on a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
purge_id (str): The id for this purge
|
purge_id: The id for this purge
|
||||||
room_id (str): The room to purge from
|
room_id: The room to purge from
|
||||||
token (str): topological token to delete events before
|
token: topological token to delete events before
|
||||||
delete_local_events (bool): True to delete local events as well as
|
delete_local_events: True to delete local events as well as remote ones
|
||||||
remote ones
|
|
||||||
"""
|
"""
|
||||||
self._purges_in_progress_by_room.add(room_id)
|
self._purges_in_progress_by_room.add(room_id)
|
||||||
try:
|
try:
|
||||||
|
@ -291,9 +295,9 @@ class PaginationHandler(object):
|
||||||
"""
|
"""
|
||||||
return self._purges_by_id.get(purge_id)
|
return self._purges_by_id.get(purge_id)
|
||||||
|
|
||||||
async def purge_room(self, room_id):
|
async def purge_room(self, room_id: str) -> None:
|
||||||
"""Purge the given room from the database"""
|
"""Purge the given room from the database"""
|
||||||
with (await self.pagination_lock.write(room_id)):
|
with await self.pagination_lock.write(room_id):
|
||||||
# check we know about the room
|
# check we know about the room
|
||||||
await self.store.get_room_version_id(room_id)
|
await self.store.get_room_version_id(room_id)
|
||||||
|
|
||||||
|
@ -307,23 +311,22 @@ class PaginationHandler(object):
|
||||||
|
|
||||||
async def get_messages(
|
async def get_messages(
|
||||||
self,
|
self,
|
||||||
requester,
|
requester: Requester,
|
||||||
room_id=None,
|
room_id: Optional[str] = None,
|
||||||
pagin_config=None,
|
pagin_config: Optional[PaginationConfig] = None,
|
||||||
as_client_event=True,
|
as_client_event: bool = True,
|
||||||
event_filter=None,
|
event_filter: Optional[Filter] = None,
|
||||||
):
|
) -> Dict[str, Any]:
|
||||||
"""Get messages in a room.
|
"""Get messages in a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requester (Requester): The user requesting messages.
|
requester: The user requesting messages.
|
||||||
room_id (str): The room they want messages from.
|
room_id: The room they want messages from.
|
||||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
pagin_config: The pagination config rules to apply, if any.
|
||||||
config rules to apply, if any.
|
as_client_event: True to get events in client-server format.
|
||||||
as_client_event (bool): True to get events in client-server format.
|
event_filter: Filter to apply to results or None
|
||||||
event_filter (Filter): Filter to apply to results or None
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Pagination API results
|
Pagination API results
|
||||||
"""
|
"""
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
@ -343,7 +346,7 @@ class PaginationHandler(object):
|
||||||
|
|
||||||
source_config = pagin_config.get_source_config("room")
|
source_config = pagin_config.get_source_config("room")
|
||||||
|
|
||||||
with (await self.pagination_lock.read(room_id)):
|
with await self.pagination_lock.read(room_id):
|
||||||
(
|
(
|
||||||
membership,
|
membership,
|
||||||
member_event_id,
|
member_event_id,
|
||||||
|
|
|
@ -20,6 +20,7 @@ from contextlib import contextmanager
|
||||||
from typing import Dict, Sequence, Set, Union
|
from typing import Dict, Sequence, Set, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
from typing_extensions import ContextManager
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.defer import CancelledError
|
from twisted.internet.defer import CancelledError
|
||||||
|
@ -338,11 +339,11 @@ class Linearizer(object):
|
||||||
|
|
||||||
|
|
||||||
class ReadWriteLock(object):
|
class ReadWriteLock(object):
|
||||||
"""A deferred style read write lock.
|
"""An async read write lock.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
with (yield read_write_lock.read("test_key")):
|
with await read_write_lock.read("test_key"):
|
||||||
# do some work
|
# do some work
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -365,8 +366,7 @@ class ReadWriteLock(object):
|
||||||
# Latest writer queued
|
# Latest writer queued
|
||||||
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
|
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def read(self, key: str) -> ContextManager:
|
||||||
def read(self, key):
|
|
||||||
new_defer = defer.Deferred()
|
new_defer = defer.Deferred()
|
||||||
|
|
||||||
curr_readers = self.key_to_current_readers.setdefault(key, set())
|
curr_readers = self.key_to_current_readers.setdefault(key, set())
|
||||||
|
@ -376,7 +376,8 @@ class ReadWriteLock(object):
|
||||||
|
|
||||||
# We wait for the latest writer to finish writing. We can safely ignore
|
# We wait for the latest writer to finish writing. We can safely ignore
|
||||||
# any existing readers... as they're readers.
|
# any existing readers... as they're readers.
|
||||||
yield make_deferred_yieldable(curr_writer)
|
if curr_writer:
|
||||||
|
await make_deferred_yieldable(curr_writer)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _ctx_manager():
|
def _ctx_manager():
|
||||||
|
@ -388,8 +389,7 @@ class ReadWriteLock(object):
|
||||||
|
|
||||||
return _ctx_manager()
|
return _ctx_manager()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def write(self, key: str) -> ContextManager:
|
||||||
def write(self, key):
|
|
||||||
new_defer = defer.Deferred()
|
new_defer = defer.Deferred()
|
||||||
|
|
||||||
curr_readers = self.key_to_current_readers.get(key, set())
|
curr_readers = self.key_to_current_readers.get(key, set())
|
||||||
|
@ -405,7 +405,7 @@ class ReadWriteLock(object):
|
||||||
curr_readers.clear()
|
curr_readers.clear()
|
||||||
self.key_to_current_writer[key] = new_defer
|
self.key_to_current_writer[key] = new_defer
|
||||||
|
|
||||||
yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _ctx_manager():
|
def _ctx_manager():
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.async_helpers import ReadWriteLock
|
from synapse.util.async_helpers import ReadWriteLock
|
||||||
|
|
||||||
|
@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
|
||||||
rwlock.read(key), # 5
|
rwlock.read(key), # 5
|
||||||
rwlock.write(key), # 6
|
rwlock.write(key), # 6
|
||||||
]
|
]
|
||||||
|
ds = [defer.ensureDeferred(d) for d in ds]
|
||||||
|
|
||||||
self._assert_called_before_not_after(ds, 2)
|
self._assert_called_before_not_after(ds, 2)
|
||||||
|
|
||||||
|
@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase):
|
||||||
with ds[6].result:
|
with ds[6].result:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
d = rwlock.write(key)
|
d = defer.ensureDeferred(rwlock.write(key))
|
||||||
self.assertTrue(d.called)
|
self.assertTrue(d.called)
|
||||||
with d.result:
|
with d.result:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
d = rwlock.read(key)
|
d = defer.ensureDeferred(rwlock.read(key))
|
||||||
self.assertTrue(d.called)
|
self.assertTrue(d.called)
|
||||||
with d.result:
|
with d.result:
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in a new issue