0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-09-30 05:29:00 +02:00

Merge remote-tracking branch 'upstream/release-v1.45'

This commit is contained in:
Tulir Asokan 2021-10-14 13:26:30 +03:00
commit 9cc654965e
15 changed files with 1024 additions and 122 deletions

View file

@ -1,3 +1,26 @@
Synapse 1.45.0rc2 (2021-10-14)
==============================
**Note:** This release candidate [fixes](https://github.com/matrix-org/synapse/issues/11053) the user directory [bug](https://github.com/matrix-org/synapse/issues/11025) present in 1.45.0rc1. However, the [performance issue](https://github.com/matrix-org/synapse/issues/11049) which appeared in v1.44.0 is yet to be resolved.
Bugfixes
--------
- Fix a long-standing bug when using multiple event persister workers where events were not correctly sent down `/sync` due to a race. ([\#11045](https://github.com/matrix-org/synapse/issues/11045))
- Fix a bug introduced in Synapse 1.45.0rc1 where the user directory would stop updating if it processed an event from a
user not in the `users` table. ([\#11053](https://github.com/matrix-org/synapse/issues/11053))
- Fix a bug introduced in Synapse v1.44.0 when logging errors during oEmbed processing. ([\#11061](https://github.com/matrix-org/synapse/issues/11061))
Internal Changes
----------------
- Add an 'approximate difference' method to `StateFilter`. ([\#10825](https://github.com/matrix-org/synapse/issues/10825))
- Fix inconsistent behavior of `get_last_client_by_ip` when reporting data that has not been stored in the database yet. ([\#10970](https://github.com/matrix-org/synapse/issues/10970))
- Fix a bug introduced in Synapse 1.21.0 that causes opentracing and Prometheus metrics for replication requests to be measured incorrectly. ([\#10996](https://github.com/matrix-org/synapse/issues/10996))
- Ensure that cache config tests do not share state. ([\#11036](https://github.com/matrix-org/synapse/issues/11036))
Synapse 1.45.0rc1 (2021-10-12) Synapse 1.45.0rc1 (2021-10-12)
============================== ==============================

6
debian/changelog vendored
View file

@ -1,3 +1,9 @@
matrix-synapse-py3 (1.45.0~rc2) stable; urgency=medium
* New synapse release 1.45.0~rc2.
-- Synapse Packaging team <packages@matrix.org> Thu, 14 Oct 2021 10:58:24 +0100
matrix-synapse-py3 (1.45.0~rc1) stable; urgency=medium matrix-synapse-py3 (1.45.0~rc1) stable; urgency=medium
[ Nick @ Beeper ] [ Nick @ Beeper ]

View file

@ -47,7 +47,7 @@ try:
except ImportError: except ImportError:
pass pass
__version__ = "1.45.0rc1" __version__ = "1.45.0rc2"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when # We import here so that we don't have to install a bunch of deps when

View file

@ -807,6 +807,14 @@ def trace(func=None, opname=None):
result.addCallbacks(call_back, err_back) result.addCallbacks(call_back, err_back)
else: else:
if inspect.isawaitable(result):
logger.error(
"@trace may not have wrapped %s correctly! "
"The function is not async but returned a %s.",
func.__qualname__,
type(result).__name__,
)
scope.__exit__(None, None, None) scope.__exit__(None, None, None)
return result return result

View file

@ -182,85 +182,87 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
) )
@trace(opname="outgoing_replication_request") @trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress()
async def send_request(*, instance_name="master", **kwargs): async def send_request(*, instance_name="master", **kwargs):
if instance_name == local_instance_name: with outgoing_gauge.track_inprogress():
raise Exception("Trying to send HTTP request to self") if instance_name == local_instance_name:
if instance_name == "master": raise Exception("Trying to send HTTP request to self")
host = master_host if instance_name == "master":
port = master_port host = master_host
elif instance_name in instance_map: port = master_port
host = instance_map[instance_name].host elif instance_name in instance_map:
port = instance_map[instance_name].port host = instance_map[instance_name].host
else: port = instance_map[instance_name].port
raise Exception( else:
"Instance %r not in 'instance_map' config" % (instance_name,) raise Exception(
"Instance %r not in 'instance_map' config" % (instance_name,)
)
data = await cls._serialize_payload(**kwargs)
url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
]
if cls.CACHE:
txn_id = random_string(10)
url_args.append(txn_id)
if cls.METHOD == "POST":
request_func = client.post_json_get_json
elif cls.METHOD == "PUT":
request_func = client.put_json
elif cls.METHOD == "GET":
request_func = client.get_json
else:
# We have already asserted in the constructor that a
# compatible was picked, but lets be paranoid.
raise Exception(
"Unknown METHOD on %s replication endpoint" % (cls.NAME,)
)
uri = "http://%s:%s/_synapse/replication/%s/%s" % (
host,
port,
cls.NAME,
"/".join(url_args),
) )
data = await cls._serialize_payload(**kwargs) try:
# We keep retrying the same request for timeouts. This is so that we
# have a good idea that the request has either succeeded or failed
# on the master, and so whether we should clean up or not.
while True:
headers: Dict[bytes, List[bytes]] = {}
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [
b"Bearer " + replication_secret
]
opentracing.inject_header_dict(headers, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
break
except RequestTimedOutError:
if not cls.RETRY_ON_TIMEOUT:
raise
url_args = [ logger.warning("%s request timed out; retrying", cls.NAME)
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
]
if cls.CACHE: # If we timed out we probably don't need to worry about backing
txn_id = random_string(10) # off too much, but lets just wait a little anyway.
url_args.append(txn_id) await clock.sleep(1)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the main process that we should send to the client. (And
# importantly, not stack traces everywhere)
_outgoing_request_counter.labels(cls.NAME, e.code).inc()
raise e.to_synapse_error()
except Exception as e:
_outgoing_request_counter.labels(cls.NAME, "ERR").inc()
raise SynapseError(502, "Failed to talk to main process") from e
if cls.METHOD == "POST": _outgoing_request_counter.labels(cls.NAME, 200).inc()
request_func = client.post_json_get_json return result
elif cls.METHOD == "PUT":
request_func = client.put_json
elif cls.METHOD == "GET":
request_func = client.get_json
else:
# We have already asserted in the constructor that a
# compatible was picked, but lets be paranoid.
raise Exception(
"Unknown METHOD on %s replication endpoint" % (cls.NAME,)
)
uri = "http://%s:%s/_synapse/replication/%s/%s" % (
host,
port,
cls.NAME,
"/".join(url_args),
)
try:
# We keep retrying the same request for timeouts. This is so that we
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
headers: Dict[bytes, List[bytes]] = {}
# Add an authorization header, if configured.
if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret]
opentracing.inject_header_dict(headers, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
break
except RequestTimedOutError:
if not cls.RETRY_ON_TIMEOUT:
raise
logger.warning("%s request timed out; retrying", cls.NAME)
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
await clock.sleep(1)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the main process that we should send to the client. (And
# importantly, not stack traces everywhere)
_outgoing_request_counter.labels(cls.NAME, e.code).inc()
raise e.to_synapse_error()
except Exception as e:
_outgoing_request_counter.labels(cls.NAME, "ERR").inc()
raise SynapseError(502, "Failed to talk to main process") from e
_outgoing_request_counter.labels(cls.NAME, 200).inc()
return result
return send_request return send_request

View file

@ -191,7 +191,7 @@ class OEmbedProvider:
except Exception as e: except Exception as e:
# Trap any exception and let the code follow as usual. # Trap any exception and let the code follow as usual.
logger.warning(f"Error parsing oEmbed metadata from {url}: {e:r}") logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
open_graph_response = {} open_graph_response = {}
cache_age = None cache_age = None

View file

@ -538,15 +538,20 @@ class ClientIpStore(ClientIpWorkerStore):
""" """
ret = await super().get_last_client_ip_by_device(user_id, device_id) ret = await super().get_last_client_ip_by_device(user_id, device_id)
# Update what is retrieved from the database with data which is pending insertion. # Update what is retrieved from the database with data which is pending
# insertion, as if it has already been stored in the database.
for key in self._batch_row_update: for key in self._batch_row_update:
uid, access_token, ip = key uid, _access_token, ip = key
if uid == user_id: if uid == user_id:
user_agent, did, last_seen = self._batch_row_update[key] user_agent, did, last_seen = self._batch_row_update[key]
if did is None:
# These updates don't make it to the `devices` table
continue
if not device_id or did == device_id: if not device_id or did == device_id:
ret[(user_id, device_id)] = { ret[(user_id, did)] = {
"user_id": user_id, "user_id": user_id,
"access_token": access_token,
"ip": ip, "ip": ip,
"user_agent": user_agent, "user_agent": user_agent,
"device_id": did, "device_id": did,

View file

@ -26,6 +26,8 @@ from typing import (
cast, cast,
) )
from synapse.api.errors import StoreError
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -383,7 +385,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"""Certain classes of local user are omitted from the user directory. """Certain classes of local user are omitted from the user directory.
Is this user one of them? Is this user one of them?
""" """
# App service users aren't usually contactable, so exclude them. # We're opting to exclude the appservice sender (user defined by the
# `sender_localpart` in the appservice registration) even though
# technically it could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice sender can be
# contacted.
if self.get_app_service_by_user_id(user) is not None:
return False
# We're opting to exclude appservice users (anyone matching the user
# namespace regex in the appservice registration) even though technically
# they could be DM-able. In the future, this could potentially
# be configurable per-appservice whether the appservice users can be
# contacted.
if self.get_if_app_services_interested_in_user(user): if self.get_if_app_services_interested_in_user(user):
# TODO we might want to make this configurable for each app service # TODO we might want to make this configurable for each app service
return False return False
@ -393,8 +407,14 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False return False
# Deactivated users aren't contactable, so should not appear in the user directory. # Deactivated users aren't contactable, so should not appear in the user directory.
if await self.get_user_deactivated_status(user): try:
if await self.get_user_deactivated_status(user):
return False
except StoreError:
# No such user in the users table. No need to do this when calling
# is_support_user---that returns False if the user is missing.
return False return False
return True return True
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool: async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:

View file

@ -15,9 +15,11 @@ import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Awaitable, Awaitable,
Collection,
Dict, Dict,
Iterable, Iterable,
List, List,
Mapping,
Optional, Optional,
Set, Set,
Tuple, Tuple,
@ -29,7 +31,7 @@ from frozendict import frozendict
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap from synapse.types import MutableStateMap, StateKey, StateMap
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
@ -134,6 +136,23 @@ class StateFilter:
include_others=True, include_others=True,
) )
@staticmethod
def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
"""
Returns a (frozen) StateFilter with the same contents as the parameters
specified here, which can be made of mutable types.
"""
types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {}
for state_types, state_keys in types.items():
if state_keys is not None:
types_with_frozen_values[state_types] = frozenset(state_keys)
else:
types_with_frozen_values[state_types] = None
return StateFilter(
frozendict(types_with_frozen_values), include_others=include_others
)
def return_expanded(self) -> "StateFilter": def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed """Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the (except for memberships). The returned filter is a superset of the
@ -356,6 +375,157 @@ class StateFilter:
return member_filter, non_member_filter return member_filter, non_member_filter
def _decompose_into_four_parts(
self,
) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
"""
Decomposes this state filter into 4 constituent parts, which can be
thought of as this:
all? - minus_wildcards + plus_wildcards + plus_state_keys
where
* all represents ALL state
* minus_wildcards represents entire state types to remove
* plus_wildcards represents entire state types to add
* plus_state_keys represents individual state keys to add
See `recompose_from_four_parts` for the other direction of this
correspondence.
"""
is_all = self.include_others
excluded_types: Set[str] = {t for t in self.types if is_all}
wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None}
concrete_keys: Set[StateKey] = set(self.concrete_types())
return (is_all, excluded_types), (wildcard_types, concrete_keys)
@staticmethod
def _recompose_from_four_parts(
all_part: bool,
minus_wildcards: Set[str],
plus_wildcards: Set[str],
plus_state_keys: Set[StateKey],
) -> "StateFilter":
"""
Recomposes a state filter from 4 parts.
See `decompose_into_four_parts` (the other direction of this
correspondence) for descriptions on each of the parts.
"""
# {state type -> set of state keys OR None for wildcard}
# (The same structure as that of a StateFilter.)
new_types: Dict[str, Optional[Set[str]]] = {}
# if we start with all, insert the excluded statetypes as empty sets
# to prevent them from being included
if all_part:
new_types.update({state_type: set() for state_type in minus_wildcards})
# insert the plus wildcards
new_types.update({state_type: None for state_type in plus_wildcards})
# insert the specific state keys
for state_type, state_key in plus_state_keys:
if state_type in new_types:
entry = new_types[state_type]
if entry is not None:
entry.add(state_key)
elif not all_part:
# don't insert if the entire type is already included by
# include_others as this would actually shrink the state allowed
# by this filter.
new_types[state_type] = {state_key}
return StateFilter.freeze(new_types, include_others=all_part)
def approx_difference(self, other: "StateFilter") -> "StateFilter":
"""
Returns a state filter which represents `self - other`.
This is useful for determining what state remains to be pulled out of the
database if we want the state included by `self` but already have the state
included by `other`.
The returned state filter
- MUST include all state events that are included by this filter (`self`)
unless they are included by `other`;
- MUST NOT include state events not included by this filter (`self`); and
- MAY be an over-approximation: the returned state filter
MAY additionally include some state events from `other`.
This implementation attempts to return the narrowest such state filter.
In the case that `self` contains wildcards for state types where
`other` contains specific state keys, an approximation must be made:
the returned state filter keeps the wildcard, as state filters are not
able to express 'all state keys except some given examples'.
e.g.
StateFilter(m.room.member -> None (wildcard))
minus
StateFilter(m.room.member -> {'@wombat:example.org'})
is approximated as
StateFilter(m.room.member -> None (wildcard))
"""
# We first transform self and other into an alternative representation:
# - whether or not they include all events to begin with ('all')
# - if so, which event types are excluded? ('excludes')
# - which entire event types to include ('wildcards')
# - which concrete state keys to include ('concrete state keys')
(self_all, self_excludes), (
self_wildcards,
self_concrete_keys,
) = self._decompose_into_four_parts()
(other_all, other_excludes), (
other_wildcards,
other_concrete_keys,
) = other._decompose_into_four_parts()
# Start with an estimate of the difference based on self
new_all = self_all
# Wildcards from the other can be added to the exclusion filter
new_excludes = self_excludes | other_wildcards
# We remove wildcards that appeared as wildcards in the other
new_wildcards = self_wildcards - other_wildcards
# We filter out the concrete state keys that appear in the other
# as wildcards or concrete state keys.
new_concrete_keys = {
(state_type, state_key)
for (state_type, state_key) in self_concrete_keys
if state_type not in other_wildcards
} - other_concrete_keys
if other_all:
if self_all:
# If self starts with all, then we add as wildcards any
# types which appear in the other's exclusion filter (but
# aren't in the self exclusion filter). This is as the other
# filter will return everything BUT the types in its exclusion, so
# we need to add those excluded types that also match the self
# filter as wildcard types in the new filter.
new_wildcards |= other_excludes.difference(self_excludes)
# If other is an `include_others` then the difference isn't.
new_all = False
# (We have no need for excludes when we don't start with all, as there
# is nothing to exclude.)
new_excludes = set()
# We also filter out all state types that aren't in the exclusion
# list of the other.
new_wildcards &= other_excludes
new_concrete_keys = {
(state_type, state_key)
for (state_type, state_key) in new_concrete_keys
if state_type in other_excludes
}
# Transform our newly-constructed state filter from the alternative
# representation back into the normal StateFilter representation.
return StateFilter._recompose_from_four_parts(
new_all, new_excludes, new_wildcards, new_concrete_keys
)
class StateGroupStorage: class StateGroupStorage:
"""High level interface to fetching state for event.""" """High level interface to fetching state for event."""

View file

@ -36,7 +36,7 @@ from typing import (
) )
import attr import attr
from sortedcontainers import SortedSet from sortedcontainers import SortedList, SortedSet
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import ( from synapse.storage.database import (
@ -265,6 +265,15 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty). # should be less than the minimum of this set (if not empty).
self._unfinished_ids: SortedSet[int] = SortedSet() self._unfinished_ids: SortedSet[int] = SortedSet()
# We also need to track when we've requested some new stream IDs but
# they haven't yet been added to the `_unfinished_ids` set. Every time
# we request a new stream ID we add the current max stream ID to the
# list, and remove it once we've added the newly allocated IDs to the
# `_unfinished_ids` set. This means that we *may* be allocated stream
# IDs above those in the list, and so we can't advance the local current
# position beyond the minimum stream ID in this list.
self._in_flight_fetches: SortedList[int] = SortedList()
# Set of local IDs that we've processed that are larger than the current # Set of local IDs that we've processed that are larger than the current
# position, due to there being smaller unpersisted IDs. # position, due to there being smaller unpersisted IDs.
self._finished_ids: Set[int] = set() self._finished_ids: Set[int] = set()
@ -290,6 +299,9 @@ class MultiWriterIdGenerator:
) )
self._known_persisted_positions: List[int] = [] self._known_persisted_positions: List[int] = []
# The maximum stream ID that we have seen been allocated across any writer.
self._max_seen_allocated_stream_id = 1
self._sequence_gen = PostgresSequenceGenerator(sequence_name) self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged. # We check that the table and sequence haven't diverged.
@ -305,6 +317,10 @@ class MultiWriterIdGenerator:
# This goes and fills out the above state from the database. # This goes and fills out the above state from the database.
self._load_current_ids(db_conn, tables) self._load_current_ids(db_conn, tables)
self._max_seen_allocated_stream_id = max(
self._current_positions.values(), default=1
)
def _load_current_ids( def _load_current_ids(
self, self,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,
@ -411,10 +427,32 @@ class MultiWriterIdGenerator:
cur.close() cur.close()
def _load_next_id_txn(self, txn: Cursor) -> int: def _load_next_id_txn(self, txn: Cursor) -> int:
return self._sequence_gen.get_next_id_txn(txn) stream_ids = self._load_next_mult_id_txn(txn, 1)
return stream_ids[0]
def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]: def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n) # We need to track that we've requested some more stream IDs, and what
# the current max allocated stream ID is. This is to prevent a race
# where we've been allocated stream IDs but they have not yet been added
# to the `_unfinished_ids` set, allowing the current position to advance
# past them.
with self._lock:
current_max = self._max_seen_allocated_stream_id
self._in_flight_fetches.add(current_max)
try:
stream_ids = self._sequence_gen.get_next_mult_txn(txn, n)
with self._lock:
self._unfinished_ids.update(stream_ids)
self._max_seen_allocated_stream_id = max(
self._max_seen_allocated_stream_id, self._unfinished_ids[-1]
)
finally:
with self._lock:
self._in_flight_fetches.remove(current_max)
return stream_ids
def get_next(self) -> AsyncContextManager[int]: def get_next(self) -> AsyncContextManager[int]:
""" """
@ -463,9 +501,6 @@ class MultiWriterIdGenerator:
next_id = self._load_next_id_txn(txn) next_id = self._load_next_id_txn(txn)
with self._lock:
self._unfinished_ids.add(next_id)
txn.call_after(self._mark_id_as_finished, next_id) txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id)
@ -497,15 +532,27 @@ class MultiWriterIdGenerator:
new_cur: Optional[int] = None new_cur: Optional[int] = None
if self._unfinished_ids: if self._unfinished_ids or self._in_flight_fetches:
# If there are unfinished IDs then the new position will be the # If there are unfinished IDs then the new position will be the
# largest finished ID less than the minimum unfinished ID. # largest finished ID strictly less than the minimum unfinished
# ID.
# The minimum unfinished ID needs to take account of both
# `_unfinished_ids` and `_in_flight_fetches`.
if self._unfinished_ids and self._in_flight_fetches:
# `_in_flight_fetches` stores the maximum safe stream ID, so
# we add one to make it equivalent to the minimum unsafe ID.
min_unfinished = min(
self._unfinished_ids[0], self._in_flight_fetches[0] + 1
)
elif self._in_flight_fetches:
min_unfinished = self._in_flight_fetches[0] + 1
else:
min_unfinished = self._unfinished_ids[0]
finished = set() finished = set()
min_unfinshed = self._unfinished_ids[0]
for s in self._finished_ids: for s in self._finished_ids:
if s < min_unfinshed: if s < min_unfinished:
if new_cur is None or new_cur < s: if new_cur is None or new_cur < s:
new_cur = s new_cur = s
else: else:
@ -575,6 +622,10 @@ class MultiWriterIdGenerator:
new_id, self._current_positions.get(instance_name, 0) new_id, self._current_positions.get(instance_name, 0)
) )
self._max_seen_allocated_stream_id = max(
self._max_seen_allocated_stream_id, new_id
)
self._add_persisted_position(new_id) self._add_persisted_position(new_id)
def get_persisted_upto_position(self) -> int: def get_persisted_upto_position(self) -> int:
@ -605,7 +656,11 @@ class MultiWriterIdGenerator:
# to report a recent position when asked, rather than a potentially old # to report a recent position when asked, rather than a potentially old
# one (if this instance hasn't written anything for a while). # one (if this instance hasn't written anything for a while).
our_current_position = self._current_positions.get(self._instance_name) our_current_position = self._current_positions.get(self._instance_name)
if our_current_position and not self._unfinished_ids: if (
our_current_position
and not self._unfinished_ids
and not self._in_flight_fetches
):
self._current_positions[self._instance_name] = max( self._current_positions[self._instance_name] = max(
our_current_position, new_id our_current_position, new_id
) )
@ -697,9 +752,6 @@ class _MultiWriterCtxManager:
db_autocommit=True, db_autocommit=True,
) )
with self.id_gen._lock:
self.id_gen._unfinished_ids.update(self.stream_ids)
if self.multiple_ids is None: if self.multiple_ids is None:
return self.stream_ids[0] * self.id_gen._return_factor return self.stream_ids[0] * self.id_gen._return_factor
else: else:

View file

@ -12,25 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest.mock import patch
from synapse.config.cache import CacheConfig, add_resizable_cache from synapse.config.cache import CacheConfig, add_resizable_cache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from tests.unittest import TestCase from tests.unittest import TestCase
# Patch the global _CACHES so that each test runs against its own state.
@patch("synapse.config.cache._CACHES", new_callable=dict)
class CacheConfigTests(TestCase): class CacheConfigTests(TestCase):
def setUp(self): def setUp(self):
# Reset caches before each test # Reset caches before each test since there's global state involved.
self.config = CacheConfig() self.config = CacheConfig()
def tearDown(self):
self.config.reset() self.config.reset()
def test_individual_caches_from_environ(self, _caches): def tearDown(self):
# Also reset the caches after each test to leave state pristine.
self.config.reset()
def test_individual_caches_from_environ(self):
""" """
Individual cache factors will be loaded from the environment. Individual cache factors will be loaded from the environment.
""" """
@ -43,7 +41,7 @@ class CacheConfigTests(TestCase):
self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
def test_config_overrides_environ(self, _caches): def test_config_overrides_environ(self):
""" """
Individual cache factors defined in the environment will take precedence Individual cache factors defined in the environment will take precedence
over those in the config. over those in the config.
@ -60,7 +58,7 @@ class CacheConfigTests(TestCase):
{"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
) )
def test_individual_instantiated_before_config_load(self, _caches): def test_individual_instantiated_before_config_load(self):
""" """
If a cache is instantiated before the config is read, it will be given If a cache is instantiated before the config is read, it will be given
the default cache size in the interim, and then resized once the config the default cache size in the interim, and then resized once the config
@ -76,7 +74,7 @@ class CacheConfigTests(TestCase):
self.assertEqual(cache.max_size, 300) self.assertEqual(cache.max_size, 300)
def test_individual_instantiated_after_config_load(self, _caches): def test_individual_instantiated_after_config_load(self):
""" """
If a cache is instantiated after the config is read, it will be If a cache is instantiated after the config is read, it will be
immediately resized to the correct size given the per_cache_factor if immediately resized to the correct size given the per_cache_factor if
@ -89,7 +87,7 @@ class CacheConfigTests(TestCase):
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 200) self.assertEqual(cache.max_size, 200)
def test_global_instantiated_before_config_load(self, _caches): def test_global_instantiated_before_config_load(self):
""" """
If a cache is instantiated before the config is read, it will be given If a cache is instantiated before the config is read, it will be given
the default cache size in the interim, and then resized to the new the default cache size in the interim, and then resized to the new
@ -104,7 +102,7 @@ class CacheConfigTests(TestCase):
self.assertEqual(cache.max_size, 400) self.assertEqual(cache.max_size, 400)
def test_global_instantiated_after_config_load(self, _caches): def test_global_instantiated_after_config_load(self):
""" """
If a cache is instantiated after the config is read, it will be If a cache is instantiated after the config is read, it will be
immediately resized to the correct size given the global factor if there immediately resized to the correct size given the global factor if there
@ -117,7 +115,7 @@ class CacheConfigTests(TestCase):
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 150) self.assertEqual(cache.max_size, 150)
def test_cache_with_asterisk_in_name(self, _caches): def test_cache_with_asterisk_in_name(self):
"""Some caches have asterisks in their name, test that they are set correctly.""" """Some caches have asterisks in their name, test that they are set correctly."""
config = { config = {
@ -143,7 +141,7 @@ class CacheConfigTests(TestCase):
add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor) add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
self.assertEqual(cache_c.max_size, 200) self.assertEqual(cache_c.max_size, 200)
def test_apply_cache_factor_from_config(self, _caches): def test_apply_cache_factor_from_config(self):
"""Caches can disable applying cache factor updates, mainly used by """Caches can disable applying cache factor updates, mainly used by
event cache size. event cache size.
""" """

View file

@ -63,7 +63,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
hostname="test", hostname="test",
id="1234", id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test", # Note: this user does not match the regex above, so that tests
# can distinguish the sender from the AS user.
sender="@as_main:test",
) )
mock_load_appservices = Mock(return_value=[self.appservice]) mock_load_appservices = Mock(return_value=[self.appservice])
@ -122,7 +124,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
{(alice, bob, private), (bob, alice, private)}, {(alice, bob, private), (bob, alice, private)},
) )
# The next three tests (test_population_excludes_*) all setup # The next four tests (test_excludes_*) all setup
# - A normal user included in the user dir # - A normal user included in the user dir
# - A public and private room created by that user # - A public and private room created by that user
# - A user excluded from the room dir, belonging to both rooms # - A user excluded from the room dir, belonging to both rooms
@ -179,6 +181,34 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
) )
self._check_only_one_user_in_directory(user, public) self._check_only_one_user_in_directory(user, public)
def test_excludes_appservice_sender(self) -> None:
user = self.register_user("user", "pass")
token = self.login(user, "pass")
room = self.helper.create_room_as(user, is_public=True, tok=token)
self.helper.join(room, self.appservice.sender, tok=self.appservice.token)
self._check_only_one_user_in_directory(user, room)
def test_user_not_in_users_table(self) -> None:
"""Unclear how it happens, but on matrix.org we've seen join events
for users who aren't in the users table. Test that we don't fall over
when processing such a user.
"""
user1 = self.register_user("user1", "pass")
token1 = self.login(user1, "pass")
room = self.helper.create_room_as(user1, is_public=True, tok=token1)
# Inject a join event for a user who doesn't exist
self.get_success(inject_member_event(self.hs, room, "@not-a-user:test", "join"))
# Another new user registers and joins the room
user2 = self.register_user("user2", "pass")
token2 = self.login(user2, "pass")
self.helper.join(room, user2, tok=token2)
# The dodgy event should not have stopped us from processing user2's join.
in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
self.assertEqual(set(in_public), {(user1, room), (user2, room)})
def _create_rooms_and_inject_memberships( def _create_rooms_and_inject_memberships(
self, creator: str, token: str, joiner: str self, creator: str, token: str, joiner: str
) -> Tuple[str, str]: ) -> Tuple[str, str]:
@ -230,7 +260,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
) )
) )
profile = self.get_success(self.store.get_user_in_directory(support_user_id)) profile = self.get_success(self.store.get_user_in_directory(support_user_id))
self.assertTrue(profile is None) self.assertIsNone(profile)
display_name = "display_name" display_name = "display_name"
profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name) profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
@ -264,7 +294,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is not in directory # profile is not in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id)) profile = self.get_success(self.store.get_user_in_directory(r_user_id))
self.assertTrue(profile is None) self.assertIsNone(profile)
# update profile after deactivation # update profile after deactivation
self.get_success( self.get_success(
@ -273,7 +303,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is furthermore not in directory # profile is furthermore not in directory
profile = self.get_success(self.store.get_user_in_directory(r_user_id)) profile = self.get_success(self.store.get_user_in_directory(r_user_id))
self.assertTrue(profile is None) self.assertIsNone(profile)
def test_handle_local_profile_change_with_appservice_user(self) -> None: def test_handle_local_profile_change_with_appservice_user(self) -> None:
# create user # create user
@ -283,7 +313,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is not in directory # profile is not in directory
profile = self.get_success(self.store.get_user_in_directory(as_user_id)) profile = self.get_success(self.store.get_user_in_directory(as_user_id))
self.assertTrue(profile is None) self.assertIsNone(profile)
# update profile # update profile
profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3") profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3")
@ -293,7 +323,28 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# profile is still not in directory # profile is still not in directory
profile = self.get_success(self.store.get_user_in_directory(as_user_id)) profile = self.get_success(self.store.get_user_in_directory(as_user_id))
self.assertTrue(profile is None) self.assertIsNone(profile)
def test_handle_local_profile_change_with_appservice_sender(self) -> None:
# profile is not in directory
profile = self.get_success(
self.store.get_user_in_directory(self.appservice.sender)
)
self.assertIsNone(profile)
# update profile
profile_info = ProfileInfo(avatar_url="avatar_url", display_name="4L1c3")
self.get_success(
self.handler.handle_local_profile_change(
self.appservice.sender, profile_info
)
)
# profile is still not in directory
profile = self.get_success(
self.store.get_user_in_directory(self.appservice.sender)
)
self.assertIsNone(profile)
def test_handle_user_deactivated_support_user(self) -> None: def test_handle_user_deactivated_support_user(self) -> None:
s_user_id = "@support:test" s_user_id = "@support:test"

View file

@ -146,6 +146,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
], ],
) )
@parameterized.expand([(False,), (True,)])
def test_get_last_client_ip_by_device(self, after_persisting: bool):
"""Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
self.reactor.advance(12345678)
user_id = "@user:id"
device_id = "MY_DEVICE"
# Insert a user IP
self.get_success(
self.store.store_device(
user_id,
device_id,
"display name",
)
)
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", device_id
)
)
if after_persisting:
# Trigger the storage loop
self.reactor.advance(10)
result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id)
)
self.assertEqual(
result,
{
(user_id, device_id): {
"user_id": user_id,
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 12345678000,
},
},
)
@parameterized.expand([(False,), (True,)]) @parameterized.expand([(False,), (True,)])
def test_get_user_ip_and_agents(self, after_persisting: bool): def test_get_user_ip_and_agents(self, after_persisting: bool):
"""Test `get_user_ip_and_agents` for persisted and unpersisted data""" """Test `get_user_ip_and_agents` for persisted and unpersisted data"""

View file

@ -21,7 +21,7 @@ from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase, TestCase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -105,7 +105,6 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
def test_get_state_for_event(self): def test_get_state_for_event(self):
# this defaults to a linear DAG as each new injection defaults to whatever # this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room. # forward extremities are currently in the DB for this room.
e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@ -483,3 +482,513 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
):
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)
def test_state_filter_difference_no_include_other_minus_no_include_other(self):
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=False,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.CanonicalAlias: {""}},
include_others=False,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
def test_state_filter_difference_include_other_minus_no_include_other(self):
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Create: None,
EventTypes.Member: set(),
EventTypes.CanonicalAlias: set(),
},
include_others=True,
),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
# This also shows that the resultant state filter is normalised.
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
StateFilter(types=frozendict(), include_others=True),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter(
types=frozendict(),
include_others=True,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.CanonicalAlias: {""},
EventTypes.Member: set(),
},
include_others=True,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
def test_state_filter_difference_include_other_minus_include_other(self):
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)
# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
EventTypes.Create: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)
def test_state_filter_difference_no_include_other_minus_include_other(self):
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=True,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)
# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)
def test_state_filter_difference_simple_cases(self):
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
"""
self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
self.assert_difference(
StateFilter.all(),
StateFilter.none(),
StateFilter.all(),
)

View file

@ -256,7 +256,7 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
self.assertEqual(users, {u1, u2, u3}) self.assertEqual(users, {u1, u2, u3})
# The next three tests (test_population_excludes_*) all set up # The next four tests (test_population_excludes_*) all set up
# - A normal user included in the user dir # - A normal user included in the user dir
# - A public and private room created by that user # - A public and private room created by that user
# - A user excluded from the room dir, belonging to both rooms # - A user excluded from the room dir, belonging to both rooms
@ -364,6 +364,21 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
# Check the AS user is not in the directory. # Check the AS user is not in the directory.
self._check_room_sharing_tables(user, public, private) self._check_room_sharing_tables(user, public, private)
def test_population_excludes_appservice_sender(self) -> None:
user = self.register_user("user", "pass")
token = self.login(user, "pass")
# Join the AS sender to rooms owned by the normal user.
public, private = self._create_rooms_and_inject_memberships(
user, token, self.appservice.sender
)
# Rebuild the directory.
self._purge_and_rebuild_user_dir()
# Check the AS sender is not in the directory.
self._check_room_sharing_tables(user, public, private)
def test_population_conceals_private_nickname(self) -> None: def test_population_conceals_private_nickname(self) -> None:
# Make a private room, and set a nickname within # Make a private room, and set a nickname within
user = self.register_user("aaaa", "pass") user = self.register_user("aaaa", "pass")