Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

David Robertson 2021-09-22 13:35:31 +01:00
commit a8340692ab
GPG key ID: 903ECE108A39DEDD
164 changed files with 1903 additions and 1288 deletions

View file

@ -61,6 +61,5 @@ jobs:
uses: peaceiris/actions-gh-pages@068dc23d9710f1ba62e86896f84735d869951305 # v3.8.0
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
keep_files: true
publish_dir: ./book
destination_dir: ./${{ steps.vars.outputs.branch-version }}

View file

@ -192,6 +192,7 @@ jobs:
volumes:
- ${{ github.workspace }}:/src
env:
SYTEST_BRANCH: ${{ github.head_ref }}
POSTGRES: ${{ matrix.postgres && 1}}
MULTI_POSTGRES: ${{ (matrix.postgres == 'multi-postgres') && 1}}
WORKERS: ${{ matrix.workers && 1 }}

View file

@ -1,7 +1,23 @@
Synapse 1.43.0rc1 (2021-09-14)
Synapse 1.43.0 (2021-09-21)
===========================
This release drops support for the deprecated, unstable API for [MSC2858 (Multiple SSO Identity Providers)](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2858-Multiple-SSO-Identity-Providers.md#unstable-prefix), as well as the undocumented `experimental.msc2858_enabled` config option. Client authors should update their clients to use the stable API, available since Synapse 1.30.
The documentation has been updated with configuration for routing `/spaces`, `/hierarchy` and `/summary` to workers. See [the upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.43/docs/upgrade.md#upgrading-to-v1430) for more details.
No significant changes since 1.43.0rc2.
Synapse 1.43.0rc2 (2021-09-17)
==============================
This release drops support for the deprecated, unstable API for [MSC2858](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2858-Multiple-SSO-Identity-Providers.md#unstable-prefix), as well as the undocumented `experimental.msc2858_enabled` config option. Client authors should update their clients to use the stable API, available since Synapse 1.30.
Bugfixes
--------
- Added opentracing logging to help debug [\#9424](https://github.com/matrix-org/synapse/issues/9424). ([\#10828](https://github.com/matrix-org/synapse/issues/10828))
Synapse 1.43.0rc1 (2021-09-14)
==============================
Features
--------

changelog.d/10659.misc (new file)

@ -0,0 +1 @@
Fix GitHub Actions config so we can run sytest on synapse from parallel branches.

View file

@ -0,0 +1 @@
Only allow the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send?chunk_id=xxx` endpoint to connect to an already existing insertion event.

changelog.d/10777.misc (new file)

@ -0,0 +1 @@
Split out [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) meta events to their own fields in the `/batch_send` response.

changelog.d/10785.misc (new file)

@ -0,0 +1 @@
Add missing type hints to REST servlets.

changelog.d/10796.misc (new file)

@ -0,0 +1 @@
Simplify the internal logic which maintains the user directory database tables.

changelog.d/10807.bugfix (new file)

@ -0,0 +1 @@
Allow sending a membership event to unban a user. Contributed by @aaronraimist.

changelog.d/10810.bugfix (new file)

@ -0,0 +1 @@
Fix a case where logging contexts would go missing when federation requests time out.

changelog.d/10812.misc (new file)

@ -0,0 +1 @@
Use direct references to config flags.

View file

@ -0,0 +1 @@
Improve oEmbed previews by processing the author name, photo, and video information.

changelog.d/10815.misc (new file)

@ -0,0 +1 @@
Specify the type of token in generic "Invalid token" error messages.

changelog.d/10816.misc (new file)

@ -0,0 +1 @@
Make `StateFilter` frozen so it is hashable.

changelog.d/10817.misc (new file)

@ -0,0 +1 @@
Add missing type hints to REST servlets.

changelog.d/10823.misc (new file)

@ -0,0 +1 @@
Add type hints to the state database.

View file

@ -1 +0,0 @@
Added opentrace logging to help debug #9424.

changelog.d/10829.misc (new file)

@ -0,0 +1 @@
Track cache eviction rates more finely in Prometheus' monitoring.

changelog.d/10831.misc (new file)

@ -0,0 +1 @@
Add missing type hints to handlers.

changelog.d/10834.misc (new file)

@ -0,0 +1 @@
Factor out PNG image data to a constant to be used in several tests.

changelog.d/10835.misc (new file)

@ -0,0 +1 @@
Add a test to ensure state events sent by modules get persisted correctly.

changelog.d/10838.misc (new file)

@ -0,0 +1 @@
Rename [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) fields and event types from `chunk` to `batch` to match the `/batch_send` endpoint.

changelog.d/10839.misc (new file)

@ -0,0 +1 @@
Rename [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` query parameter from `?prev_event` to more obvious usage with `?prev_event_id`.

changelog.d/10843.bugfix (new file)

@ -0,0 +1 @@
Fix a bug causing the `remove_stale_pushers` background job to repeatedly fail and log errors. This bug affected Synapse servers that had been upgraded from version 1.28 or older and are using SQLite.

changelog.d/10845.doc (new file)

@ -0,0 +1 @@
Fix some crashes in the Module API example code, by adding JSON encoding/decoding.

changelog.d/10856.misc (new file)

@ -0,0 +1 @@
Add missing type hints to handlers.

changelog.d/10859.bugfix (new file)

@ -0,0 +1 @@
Fix a bug in Unicode support of the room search admin API. It is now possible to search for rooms with non-ASCII characters.

changelog.d/10867.misc (new file)

@ -0,0 +1 @@
Add type hints to `synapse.http.site`.

changelog.d/10869.doc (new file)

@ -0,0 +1 @@
Properly remove deleted files from GitHub pages when generating the documentation.

changelog.d/10879.misc (new file)

@ -0,0 +1 @@
Include outlier status when we log V2 or V3 events.

debian/changelog

@ -1,3 +1,15 @@
matrix-synapse-py3 (1.43.0) stable; urgency=medium
* New synapse release 1.43.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 21 Sep 2021 11:49:05 +0100
matrix-synapse-py3 (1.43.0~rc2) stable; urgency=medium
* New synapse release 1.43.0~rc2.
-- Synapse Packaging team <packages@matrix.org> Fri, 17 Sep 2021 10:43:21 +0100
matrix-synapse-py3 (1.43.0~rc1) stable; urgency=medium
* New synapse release 1.43.0~rc1.

View file

@ -25,16 +25,14 @@ When Synapse is asked to preview a URL it does the following:
3. Kicks off a background process to generate a preview:
1. Checks the database cache by URL and timestamp and returns the result if it
has not expired and was successful (a 2xx return code).
2. Checks if the URL matches an oEmbed pattern. If it does, fetch the oEmbed
response. If this is an image, replace the URL to fetch and continue. If
it is HTML content, use the HTML as the document and continue.
3. If it doesn't match an oEmbed pattern, downloads the URL and stores it
into a file via the media storage provider and saves the local media
metadata.
5. If the media is an image:
2. Checks if the URL matches an [oEmbed](https://oembed.com/) pattern. If it
does, update the URL to download.
3. Downloads the URL and stores it into a file via the media storage provider
and saves the local media metadata.
4. If the media is an image:
1. Generates thumbnails.
2. Generates an Open Graph response based on image properties.
6. If the media is HTML:
5. If the media is HTML:
1. Decodes the HTML via the stored file.
2. Generates an Open Graph response from the HTML.
3. If an image exists in the Open Graph response:
@ -42,6 +40,13 @@ When Synapse is asked to preview a URL it does the following:
provider and saves the local media metadata.
2. Generates thumbnails.
3. Updates the Open Graph response based on image properties.
6. If the media is JSON and an oEmbed URL was found:
1. Convert the oEmbed response to an Open Graph response.
2. If a thumbnail or image is in the oEmbed response:
1. Downloads the URL and stores it into a file via the media storage
provider and saves the local media metadata.
2. Generates thumbnails.
3. Updates the Open Graph response based on image properties.
7. Stores the result in the database cache.
4. Returns the result.
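The flow above condenses into a short sketch. This is a minimal illustration of the branching in steps 2 through 7, not Synapse's actual implementation; `Media`, `oembed_url_for` and `fetch` are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Media:
    content_type: str
    body: bytes


def oembed_url_for(url: str) -> Optional[str]:
    # Stand-in: the real code matches the URL against known oEmbed patterns.
    return None


def fetch(url: str) -> Media:
    # Stand-in: the real code downloads via the media storage provider and
    # saves the local media metadata.
    return Media("text/html", b"<html><title>Example</title></html>")


def generate_preview(url: str) -> dict:
    oembed_url = oembed_url_for(url)        # step 2: oEmbed pattern match
    if oembed_url is not None:
        url = oembed_url                    # update the URL to download
    media = fetch(url)                      # step 3: download and store

    if media.content_type.startswith("image/"):
        return {"og:image": url}            # step 4: thumbnails, OG from image
    if media.content_type == "text/html":
        return {"og:title": "Example"}      # step 5: OG from the HTML document
    if media.content_type == "application/json" and oembed_url is not None:
        return {"og:title": "..."}          # step 6: oEmbed response -> OG
    return {}                               # step 7: the result is then cached
```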

View file

@ -136,9 +136,9 @@ class IsUserEvilResource(Resource):
self.evil_users = config.get("evil_users") or []
def render_GET(self, request: Request):
user = request.args.get(b"user")[0]
user = request.args.get(b"user")[0].decode()
request.setHeader(b"Content-Type", b"application/json")
return json.dumps({"evil": user in self.evil_users})
return json.dumps({"evil": user in self.evil_users}).encode()
class ListSpamChecker:
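The two changed lines fix a type mismatch in this example: Twisted populates `request.args` with `bytes` keys and values, and `render_GET` must return `bytes`. The same round-trip as a standalone snippet:

```python
import json

# request.args maps bytes keys to lists of bytes values, so the value must be
# decoded before comparing it against the (str) entries in evil_users.
args = {b"user": [b"@alice:example.org"]}
user = args.get(b"user")[0].decode()

# render_GET must return bytes, so the JSON string is encoded on the way out.
body = json.dumps({"evil": user in ["@alice:example.org"]}).encode()
assert body == b'{"evil": true}'
```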

View file

@ -2362,12 +2362,16 @@ user_directory:
#enabled: false
# Defines whether to search all users visible to your HS when searching
# the user directory, rather than limiting to users visible in public
# rooms. Defaults to false.
# the user directory. If false, search results will only contain users
# visible in public rooms and users sharing a room with the requester.
# Defaults to false.
#
# If you set it true, you'll have to rebuild the user_directory search
# indexes, see:
# https://matrix-org.github.io/synapse/latest/user_directory.html
# NB. If you set this to true, and the last time the user_directory search
# indexes were (re)built was before Synapse 1.44, you'll have to
# rebuild the indexes in order to search through all known users.
# These indexes are built the first time Synapse starts; admins can
# manually trigger a rebuild following the instructions at
# https://matrix-org.github.io/synapse/latest/user_directory.html
#
# Uncomment to return search results containing all known users, even if that
# user does not share a room with the requester.

View file

@ -60,6 +60,7 @@ files =
synapse/storage/databases/main/session.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/databases/state,
synapse/storage/database.py,
synapse/storage/engines,
synapse/storage/keys.py,
@ -86,10 +87,14 @@ files =
tests/handlers/test_sync.py,
tests/rest/client/test_login.py,
tests/rest/client/test_auth.py,
tests/storage/test_state.py,
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py
[mypy-synapse.rest.client.*]
[mypy-synapse.handlers.*]
disallow_untyped_defs = True
[mypy-synapse.rest.*]
disallow_untyped_defs = True
[mypy-synapse.util.batching_queue]

View file

@ -47,7 +47,7 @@ try:
except ImportError:
pass
__version__ = "1.43.0rc1"
__version__ = "1.43.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View file

@ -70,8 +70,8 @@ class Auth:
self._auth_blocking = AuthBlocking(self.hs)
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
async def check_user_in_room(

View file

@ -30,13 +30,15 @@ class AuthBlocking:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self._server_notices_mxid = hs.config.server_notices_mxid
self._hs_disabled = hs.config.hs_disabled
self._hs_disabled_message = hs.config.hs_disabled_message
self._admin_contact = hs.config.admin_contact
self._max_mau_value = hs.config.max_mau_value
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self._hs_disabled = hs.config.server.hs_disabled
self._hs_disabled_message = hs.config.server.hs_disabled_message
self._admin_contact = hs.config.server.admin_contact
self._max_mau_value = hs.config.server.max_mau_value
self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
self._mau_limits_reserved_threepids = (
hs.config.server.mau_limits_reserved_threepids
)
self._server_name = hs.hostname
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
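A minimal sketch of the pattern these hunks migrate to: flat attributes on the root config move onto per-section objects, so `hs.config.hs_disabled` becomes `hs.config.server.hs_disabled`. The classes below are illustrative stand-ins, not Synapse's config machinery.

```python
from dataclasses import dataclass, field


@dataclass
class ServerConfig:
    hs_disabled: bool = False
    hs_disabled_message: str = ""
    admin_contact: str = ""
    max_mau_value: int = 0
    limit_usage_by_mau: bool = False


@dataclass
class RootConfig:
    server: ServerConfig = field(default_factory=ServerConfig)


config = RootConfig()
# Namespaced access, as in the updated AuthBlocking:
hs_disabled = config.server.hs_disabled
```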

View file

@ -121,7 +121,7 @@ class EventTypes:
SpaceParent = "m.space.parent"
MSC2716_INSERTION = "org.matrix.msc2716.insertion"
MSC2716_CHUNK = "org.matrix.msc2716.chunk"
MSC2716_BATCH = "org.matrix.msc2716.batch"
MSC2716_MARKER = "org.matrix.msc2716.marker"
@ -209,11 +209,11 @@ class EventContentFields:
# Used on normal messages to indicate they were historically imported after the fact
MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
# For "insertion" events to indicate what the next chunk ID should be in
# For "insertion" events to indicate what the next batch ID should be in
# order to connect to it
MSC2716_NEXT_CHUNK_ID = "org.matrix.msc2716.next_chunk_id"
# Used on "chunk" events to indicate which insertion event it connects to
MSC2716_CHUNK_ID = "org.matrix.msc2716.chunk_id"
MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id"
# Used on "batch" events to indicate which insertion event it connects to
MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id"
# For "marker" events
MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion"
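Illustrative only: under the renamed constants, an insertion event advertises the batch ID that the next batch should reference, and each batch event points back at it. These dicts are hypothetical event contents, not copied from the MSC.

```python
insertion_content = {
    "org.matrix.msc2716.next_batch_id": "batch-123",  # MSC2716_NEXT_BATCH_ID
    "org.matrix.msc2716.historical": True,            # MSC2716_HISTORICAL
}

batch_content = {
    "org.matrix.msc2716.batch_id": "batch-123",       # MSC2716_BATCH_ID
}
```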

View file

@ -244,24 +244,8 @@ class RoomVersions:
msc2716_historical=False,
msc2716_redactions=False,
)
MSC2716 = RoomVersion(
"org.matrix.msc2716",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc3375_redaction_rules=False,
msc2403_knocking=True,
msc2716_historical=True,
msc2716_redactions=False,
)
MSC2716v2 = RoomVersion(
"org.matrix.msc2716v2",
MSC2716v3 = RoomVersion(
"org.matrix.msc2716v3",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
@ -289,9 +273,9 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V6,
RoomVersions.MSC2176,
RoomVersions.V7,
RoomVersions.MSC2716,
RoomVersions.V8,
RoomVersions.V9,
RoomVersions.MSC2716v3,
)
}

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
from typing import Any, List, Tuple, Type
from synapse.util.module_loader import load_module
@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
section = "authproviders"
def read_config(self, config, **kwargs):
self.password_providers: List[Any] = []
self.password_providers: List[Tuple[Type, Any]] = []
providers = []
# We want to be backwards compatible with the old `ldap_config`
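A sketch of why the annotation tightens to `List[Tuple[Type, Any]]`: each configured provider is stored as a (provider class, provider config) pair and instantiated later. `ExampleProvider` is hypothetical.

```python
from typing import Any, List, Tuple, Type


class ExampleProvider:
    def __init__(self, config: Any, account_handler: Any = None) -> None:
        self.config = config


password_providers: List[Tuple[Type, Any]] = [
    (ExampleProvider, {"endpoint": "ldap://localhost"}),
]

# Instantiation happens later, roughly as PasswordProvider.load does it:
for provider_class, provider_config in password_providers:
    provider = provider_class(config=provider_config)
```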

View file

@ -45,12 +45,16 @@ class UserDirectoryConfig(Config):
#enabled: false
# Defines whether to search all users visible to your HS when searching
# the user directory, rather than limiting to users visible in public
# rooms. Defaults to false.
# the user directory. If false, search results will only contain users
# visible in public rooms and users sharing a room with the requester.
# Defaults to false.
#
# If you set it true, you'll have to rebuild the user_directory search
# indexes, see:
# https://matrix-org.github.io/synapse/latest/user_directory.html
# NB. If you set this to true, and the last time the user_directory search
# indexes were (re)built was before Synapse 1.44, you'll have to
# rebuild the indexes in order to search through all known users.
# These indexes are built the first time Synapse starts; admins can
# manually trigger a rebuild following the instructions at
# https://matrix-org.github.io/synapse/latest/user_directory.html
#
# Uncomment to return search results containing all known users, even if that
# user does not share a room with the requester.

View file

@ -102,7 +102,7 @@ class FederationPolicyForHTTPS:
self._config = config
# Check if we're using a custom list of a CA certificates
trust_root = config.federation_ca_trust_root
trust_root = config.tls.federation_ca_trust_root
if trust_root is None:
# Use CA root certs provided by OpenSSL
trust_root = platformTrust()
@ -113,7 +113,7 @@ class FederationPolicyForHTTPS:
# moving to TLS 1.2 by default, we want to respect the config option if
# it is set to 1.0 (which the alternate option, raiseMinimumTo, will not
# let us do).
minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version]
minTLS = _TLS_VERSION_MAP[config.tls.federation_client_minimum_tls_version]
_verify_ssl = CertificateOptions(
trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS
@ -125,10 +125,10 @@ class FederationPolicyForHTTPS:
self._no_verify_ssl_context = _no_verify_ssl.getContext()
self._no_verify_ssl_context.set_info_callback(_context_info_cb)
self._should_verify = self._config.federation_verify_certificates
self._should_verify = self._config.tls.federation_verify_certificates
self._federation_certificate_verification_whitelist = (
self._config.federation_certificate_verification_whitelist
self._config.tls.federation_certificate_verification_whitelist
)
def get_options(self, host: bytes):

View file

@ -572,7 +572,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers
self.key_servers = self.config.key.key_servers
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]

View file

@ -213,7 +213,7 @@ def check(
if (
event.type == EventTypes.MSC2716_INSERTION
or event.type == EventTypes.MSC2716_CHUNK
or event.type == EventTypes.MSC2716_BATCH
or event.type == EventTypes.MSC2716_MARKER
):
check_historical(room_version_obj, event, auth_events)
@ -552,14 +552,14 @@ def check_historical(
auth_events: StateMap[EventBase],
) -> None:
"""Check whether the event sender is allowed to send historical related
events like "insertion", "chunk", and "marker".
events like "insertion", "batch", and "marker".
Returns:
None
Raises:
AuthError if the event sender is not allowed to send historical related events
("insertion", "chunk", and "marker").
("insertion", "batch", and "marker").
"""
# Ignore the auth checks in room versions that do not support historical
# events
@ -573,7 +573,7 @@ def check_historical(
if user_level < historical_level:
raise AuthError(
403,
'You don\'t have permission to send historical related events ("insertion", "chunk", and "marker")',
'You don\'t have permission to send historical related events ("insertion", "batch", and "marker")',
)
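The gate itself is a plain power-level comparison. A toy version, assuming the two levels have already been resolved from the room's `m.room.power_levels` state:

```python
def check_historical_allowed(user_level: int, historical_level: int) -> None:
    # Mirrors the check above: the sender needs at least the configured
    # "historical" power level to send insertion/batch/marker events.
    if user_level < historical_level:
        raise PermissionError(
            "You don't have permission to send historical related events"
        )


check_historical_allowed(user_level=100, historical_level=50)  # ok, no raise
```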

View file

@ -344,6 +344,18 @@ class EventBase(metaclass=abc.ABCMeta):
# this will be a no-op if the event dict is already frozen.
self._dict = freeze(self._dict)
def __str__(self):
return self.__repr__()
def __repr__(self):
return "<%s event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
self.__class__.__name__,
self.event_id,
self.get("type", None),
self.get("state_key", None),
self.internal_metadata.is_outlier(),
)
class FrozenEvent(EventBase):
format_version = EventFormatVersions.V1 # All events of this type are V1
@ -392,17 +404,6 @@ class FrozenEvent(EventBase):
def event_id(self) -> str:
return self._event_id
def __str__(self):
return self.__repr__()
def __repr__(self):
return "<FrozenEvent event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
self.get("event_id", None),
self.get("type", None),
self.get("state_key", None),
self.internal_metadata.is_outlier(),
)
class FrozenEventV2(EventBase):
format_version = EventFormatVersions.V2 # All events of this type are V2
@ -478,17 +479,6 @@ class FrozenEventV2(EventBase):
"""
return self.auth_events
def __str__(self):
return self.__repr__()
def __repr__(self):
return "<%s event_id=%r, type=%r, state_key=%r>" % (
self.__class__.__name__,
self.event_id,
self.get("type", None),
self.get("state_key", None),
)
class FrozenEventV3(FrozenEventV2):
"""FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""

View file

@ -141,9 +141,9 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
add_fields("redacts")
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION:
add_fields(EventContentFields.MSC2716_NEXT_CHUNK_ID)
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_CHUNK:
add_fields(EventContentFields.MSC2716_CHUNK_ID)
add_fields(EventContentFields.MSC2716_NEXT_BATCH_ID)
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
add_fields(EventContentFields.MSC2716_BATCH_ID)
elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
add_fields(EventContentFields.MSC2716_MARKER_INSERTION)

View file

@ -1237,7 +1237,7 @@ class FederationHandlerRegistry:
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
if not self.config.use_presence and edu_type == EduTypes.Presence:
if not self.config.server.use_presence and edu_type == EduTypes.Presence:
return
# Check if we have a handler on this instance

View file

@ -594,7 +594,7 @@ class FederationSender(AbstractFederationSender):
destinations (list[str])
"""
if not states or not self.hs.config.use_presence:
if not states or not self.hs.config.server.use_presence:
# No-op if presence is disabled.
return

View file

@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.ratelimiting import Ratelimiter
from synapse.types import Requester
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -63,16 +64,21 @@ class BaseHandler:
self.event_builder_factory = hs.get_event_builder_factory()
async def ratelimit(self, requester, update=True, is_admin_redaction=False):
async def ratelimit(
self,
requester: Requester,
update: bool = True,
is_admin_redaction: bool = False,
) -> None:
"""Ratelimits requests.
Args:
requester (Requester)
update (bool): Whether to record that a request is being processed.
requester
update: Whether to record that a request is being processed.
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
is_admin_redaction (bool): Whether this is a room admin/moderator
is_admin_redaction: Whether this is a room admin/moderator
redacting an event. If so then we may apply different
ratelimits depending on config.

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, Collection, List, Optional, Tuple
from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
@ -21,6 +21,7 @@ from synapse.replication.http.account_data import (
ReplicationRoomAccountDataRestServlet,
ReplicationUserAccountDataRestServlet,
)
from synapse.streams import EventSource
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
@ -163,7 +164,7 @@ class AccountDataHandler:
return response["max_stream_id"]
class AccountDataEventSource:
class AccountDataEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -171,7 +172,13 @@ class AccountDataEventSource:
return self.store.get_max_account_data_stream_id()
async def get_new_events(
self, user: UserID, from_key: int, **kwargs
self,
user: UserID,
from_key: int,
limit: Optional[int],
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
user_id = user.to_string()
last_stream_id = from_key
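The new base class pins down the stream-key and event types per source, so `AccountDataEventSource(EventSource[int, JsonDict])` fixes the key to `int` and the payload to `JsonDict`. A sketch of what the generic base plausibly looks like, inferred from the signature above rather than copied from `synapse/streams/__init__.py`:

```python
import abc
from typing import Collection, Generic, List, Optional, Tuple, TypeVar

K = TypeVar("K")
E = TypeVar("E")


class EventSource(abc.ABC, Generic[K, E]):
    @abc.abstractmethod
    async def get_new_events(
        self,
        user: object,  # UserID in Synapse
        from_key: K,
        limit: Optional[int],
        room_ids: Collection[str],
        is_guest: bool,
        explicit_room_id: Optional[str] = None,
    ) -> Tuple[List[E], K]:
        ...
```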

View file

@ -99,7 +99,7 @@ class AccountValidityHandler:
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
):
) -> None:
"""Register callbacks from module for each hook."""
if is_user_expired is not None:
self._is_user_expired_callbacks.append(is_user_expired)
@ -165,7 +165,7 @@ class AccountValidityHandler:
return False
async def on_user_registration(self, user_id: str):
async def on_user_registration(self, user_id: str) -> None:
"""Tell third-party modules about a user's registration.
Args:

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
from prometheus_client import Counter
@ -58,7 +58,7 @@ class ApplicationServicesHandler:
self.current_max = 0
self.is_processing = False
def notify_interested_services(self, max_token: RoomStreamToken):
def notify_interested_services(self, max_token: RoomStreamToken) -> None:
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
@ -82,7 +82,7 @@ class ApplicationServicesHandler:
self._notify_interested_services(max_token)
@wrap_as_background_process("notify_interested_services")
async def _notify_interested_services(self, max_token: RoomStreamToken):
async def _notify_interested_services(self, max_token: RoomStreamToken) -> None:
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
@ -100,7 +100,7 @@ class ApplicationServicesHandler:
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
async def handle_event(event):
async def handle_event(event: EventBase) -> None:
# Gather interested services
services = await self._get_services_for_event(event)
if len(services) == 0:
@ -116,9 +116,9 @@ class ApplicationServicesHandler:
if not self.started_scheduler:
async def start_scheduler():
async def start_scheduler() -> None:
try:
return await self.scheduler.start()
await self.scheduler.start()
except Exception:
logger.error("Application Services Failure")
@ -137,7 +137,7 @@ class ApplicationServicesHandler:
"appservice_sender"
).observe((now - ts) / 1000)
async def handle_room_events(events):
async def handle_room_events(events: Iterable[EventBase]) -> None:
for event in events:
await handle_event(event)
@ -184,7 +184,7 @@ class ApplicationServicesHandler:
stream_key: str,
new_token: Optional[int],
users: Optional[Collection[Union[str, UserID]]] = None,
):
) -> None:
"""This is called by the notifier in the background
when an ephemeral event is handled by the homeserver.
@ -226,7 +226,7 @@ class ApplicationServicesHandler:
stream_key: str,
new_token: Optional[int],
users: Collection[Union[str, UserID]],
):
) -> None:
logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
@ -254,7 +254,7 @@ class ApplicationServicesHandler:
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
typing_source = self.event_sources.sources["typing"]
typing_source = self.event_sources.sources.typing
# Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as(
service=service,
@ -269,7 +269,7 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
receipts_source = self.event_sources.sources["receipt"]
receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key
)
@ -279,7 +279,7 @@ class ApplicationServicesHandler:
self, service: ApplicationService, users: Collection[Union[str, UserID]]
) -> List[JsonDict]:
events: List[JsonDict] = []
presence_source = self.event_sources.sources["presence"]
presence_source = self.event_sources.sources.presence
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
)
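Switching from `sources["typing"]` to `sources.typing` suggests the container became an attrs/dataclass-style object, so mypy sees a concrete type per source instead of an untyped dict value. An illustrative stand-in:

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyEventSources:
    presence: Any
    typing: Any
    receipt: Any


sources = ToyEventSources(presence=object(), typing=object(), receipt=object())
typing_source = sources.typing       # typed attribute access (new style)
# typing_source = sources["typing"]  # untyped dict lookup (old style)
```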

View file

@ -29,6 +29,7 @@ from typing import (
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
)
@ -439,7 +440,7 @@ class AuthHandler(BaseHandler):
return ui_auth_types
def get_enabled_auth_types(self):
def get_enabled_auth_types(self) -> Iterable[str]:
"""Return the enabled user-interactive authentication types
Returns the UI-Auth types which are supported by the homeserver's current
@ -702,7 +703,7 @@ class AuthHandler(BaseHandler):
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
async def _expire_old_sessions(self):
async def _expire_old_sessions(self) -> None:
"""
Invalidate any user interactive authentication sessions that have expired.
"""
@ -1347,12 +1348,12 @@ class AuthHandler(BaseHandler):
try:
res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
await self.auth.check_auth_blocking(res.user_id)
return res
async def delete_access_token(self, access_token: str):
async def delete_access_token(self, access_token: str) -> None:
"""Invalidate a single access token
Args:
@ -1381,7 +1382,7 @@ class AuthHandler(BaseHandler):
user_id: str,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
):
) -> None:
"""Invalidate access tokens belonging to a user
Args:
@ -1409,7 +1410,7 @@ class AuthHandler(BaseHandler):
async def add_threepid(
self, user_id: str, medium: str, address: str, validated_at: int
):
) -> None:
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@ -1480,7 +1481,7 @@ class AuthHandler(BaseHandler):
Hashed password.
"""
def _do_hash():
def _do_hash() -> str:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
@ -1504,7 +1505,7 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash.
"""
def _do_validate_hash(checked_hash: bytes):
def _do_validate_hash(checked_hash: bytes) -> bool:
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
@ -1581,7 +1582,7 @@ class AuthHandler(BaseHandler):
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
):
) -> None:
"""Having figured out a mxid for this user, complete the HTTP request
Args:
@ -1627,7 +1628,7 @@ class AuthHandler(BaseHandler):
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
):
) -> None:
"""
The synchronous portion of complete_sso_login.
@ -1726,7 +1727,7 @@ class AuthHandler(BaseHandler):
del self._extra_attributes[user_id]
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
url_parts = list(urllib.parse.urlparse(url))
query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
query.append((param_name, param))
@ -1734,9 +1735,9 @@ class AuthHandler(BaseHandler):
return urllib.parse.urlunparse(url_parts)
@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class MacaroonGenerator:
hs = attr.ib()
hs: "HomeServer"
def generate_guest_access_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
@ -1816,7 +1817,9 @@ class PasswordProvider:
"""
@classmethod
def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
def load(
cls, module: Type, config: JsonDict, module_api: ModuleApi
) -> "PasswordProvider":
try:
pp = module(config=config, account_handler=module_api)
except Exception as e:
@ -1824,7 +1827,7 @@ class PasswordProvider:
raise
return cls(pp, module_api)
def __init__(self, pp, module_api: ModuleApi):
def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
self._pp = pp
self._module_api = module_api
@ -1838,7 +1841,7 @@ class PasswordProvider:
if g:
self._supported_login_types.update(g())
def __str__(self):
def __str__(self) -> str:
return str(self._pp)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
@ -1876,19 +1879,19 @@ class PasswordProvider:
"""
# first grandfather in a call to check_password
if login_type == LoginType.PASSWORD:
g = getattr(self._pp, "check_password", None)
if g:
check_password = getattr(self._pp, "check_password", None)
if check_password:
qualified_user_id = self._module_api.get_qualified_user_id(username)
is_valid = await self._pp.check_password(
is_valid = await check_password(
qualified_user_id, login_dict["password"]
)
if is_valid:
return qualified_user_id, None
g = getattr(self._pp, "check_auth", None)
if not g:
check_auth = getattr(self._pp, "check_auth", None)
if not check_auth:
return None
result = await g(username, login_type, login_dict)
result = await check_auth(username, login_type, login_dict)
# Check if the return value is a str or a tuple
if isinstance(result, str):
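Binding the `getattr` result to a named local (rather than the generic `g`) reads better and lets the type checker narrow the Optional before the call. The grandfathering logic in miniature, with a hypothetical legacy provider:

```python
import asyncio


class LegacyProvider:
    async def check_password(self, user_id: str, password: str) -> bool:
        return password == "hunter2"


async def check_auth_compat(provider: object, user_id: str, password: str) -> bool:
    # Probe for the optional legacy method; fall through if it is absent.
    check_password = getattr(provider, "check_password", None)
    if check_password:
        return await check_password(user_id, password)
    return False


print(asyncio.run(check_auth_compat(LegacyProvider(), "@a:hs", "hunter2")))
```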

View file

@ -34,20 +34,20 @@ logger = logging.getLogger(__name__)
class CasError(Exception):
"""Used to catch errors when validating the CAS ticket."""
def __init__(self, error, error_description=None):
def __init__(self, error: str, error_description: Optional[str] = None):
self.error = error
self.error_description = error_description
def __str__(self):
def __str__(self) -> str:
if self.error_description:
return f"{self.error}: {self.error_description}"
return self.error
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class CasResponse:
username = attr.ib(type=str)
attributes = attr.ib(type=Dict[str, List[Optional[str]]])
username: str
attributes: Dict[str, List[Optional[str]]]
class CasHandler:
@ -133,11 +133,9 @@ class CasHandler:
body = pde.response
except HttpResponseException as e:
description = (
(
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code),
)
'Authorization server responded with a "{status}" error '
"while exchanging the authorization code."
).format(status=e.code)
raise CasError("server_error", description) from e
return self._parse_cas_response(body)
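The bug fixed here is a classic: parenthesising the `.format()` call and leaving a trailing comma made `description` a 1-tuple rather than a string. Reduced to its essence:

```python
description = (
    'Authorization server responded with a "{status}" error'.format(status=502),
)
assert isinstance(description, tuple)  # oops: a 1-tuple

description = (
    'Authorization server responded with a "{status}" error'
).format(status=502)
assert isinstance(description, str)    # fixed: a plain string
```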

View file

@ -257,11 +257,8 @@ class DeactivateAccountHandler(BaseHandler):
"""
# Add the user to the directory, if necessary.
user = UserID.from_string(user_id)
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(user.localpart)
await self.user_directory_handler.handle_local_profile_change(
user_id, profile
)
profile = await self.store.get_profileinfo(user.localpart)
await self.user_directory_handler.handle_local_profile_change(user_id, profile)
# Ensure the user is not marked as erased.
await self.store.mark_user_not_erased(user_id)

View file

@ -267,7 +267,7 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
def _check_device_name_length(self, name: Optional[str]):
def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.

View file

@ -202,7 +202,7 @@ class E2eKeysHandler:
# Now fetch any devices that we don't have in our cache
@trace
async def do_remote_query(destination):
async def do_remote_query(destination: str) -> None:
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
@ -447,7 +447,7 @@ class E2eKeysHandler:
}
@trace
async def claim_client_keys(destination):
async def claim_client_keys(destination: str) -> None:
set_tag("destination", destination)
device_keys = remote_queries[destination]
try:

View file

@ -25,6 +25,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.types import StateMap, get_domain_from_id
from synapse.util.metrics import Measure
@ -45,7 +46,11 @@ class EventAuthHandler:
self._server_name = hs.hostname
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
self,
room_version: str,
event: EventBase,
context: EventContext,
do_sig_check: bool = True,
) -> None:
auth_event_ids = event.auth_event_ids()
auth_events_by_id = await self._store.get_events(auth_event_ids)

View file

@ -1221,136 +1221,6 @@ class FederationHandler(BaseHandler):
return missing_events
async def construct_auth_difference(
self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
) -> Dict:
"""Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
Params:
local_auth
remote_auth
Returns:
dict
"""
logger.debug("construct_auth_difference Start!")
# TODO: Make sure we are OK with local_auth or remote_auth having more
# auth events in them than strictly necessary.
def sort_fun(ev):
return ev.depth, ev.event_id
logger.debug("construct_auth_difference after sort_fun!")
# We find the differences by starting at the "bottom" of each list
# and iterating up on both lists. The lists are ordered by depth and
# then event_id, we iterate up both lists until we find the event ids
# don't match. Then we look at depth/event_id to see which side is
# missing that event, and iterate only up that list. Repeat.
remote_list = list(remote_auth)
remote_list.sort(key=sort_fun)
local_list = list(local_auth)
local_list.sort(key=sort_fun)
local_iter = iter(local_list)
remote_iter = iter(remote_list)
logger.debug("construct_auth_difference before get_next!")
def get_next(it, opt=None):
try:
return next(it)
except Exception:
return opt
current_local = get_next(local_iter)
current_remote = get_next(remote_iter)
logger.debug("construct_auth_difference before while")
missing_remotes = []
missing_locals = []
while current_local or current_remote:
if current_remote is None:
missing_locals.append(current_local)
current_local = get_next(local_iter)
continue
if current_local is None:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
if current_local.event_id == current_remote.event_id:
current_local = get_next(local_iter)
current_remote = get_next(remote_iter)
continue
if current_local.depth < current_remote.depth:
missing_locals.append(current_local)
current_local = get_next(local_iter)
continue
if current_local.depth > current_remote.depth:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
# They have the same depth, so we fall back to the event_id order
if current_local.event_id < current_remote.event_id:
missing_locals.append(current_local)
current_local = get_next(local_iter)
if current_local.event_id > current_remote.event_id:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
logger.debug("construct_auth_difference after while")
# missing locals should be sent to the server
# We should find why we are missing remotes, as they will have been
# rejected.
# Remove events from missing_remotes if they are referencing a missing
# remote. We only care about the "root" rejected ones.
missing_remote_ids = [e.event_id for e in missing_remotes]
base_remote_rejected = list(missing_remotes)
for e in missing_remotes:
for e_id in e.auth_event_ids():
if e_id in missing_remote_ids:
try:
base_remote_rejected.remove(e)
except ValueError:
pass
reason_map = {}
for e in base_remote_rejected:
reason = await self.store.get_rejection_reason(e.event_id)
if reason is None:
# TODO: e is not in the current state, so we should
# construct some proof of that.
continue
reason_map[e.event_id] = reason
logger.debug("construct_auth_difference returning")
return {
"auth_chain": local_auth,
"rejects": {
e.event_id: {"reason": reason_map[e.event_id], "proof": None}
for e in base_remote_rejected
},
"missing": [e.event_id for e in missing_locals],
}
@log_function
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict

View file

@ -1016,7 +1016,7 @@ class FederationEventHandler:
except Exception:
logger.exception("Failed to resync device for %s", sender)
async def _handle_marker_event(self, origin: str, marker_event: EventBase):
async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
"""Handles backfilling the insertion event when we receive a marker
event that points to one.
@ -1109,7 +1109,7 @@ class FederationEventHandler:
event_map: Dict[str, EventBase] = {}
async def get_event(event_id: str):
async def get_event(event_id: str) -> None:
with nested_logging_context(event_id):
try:
event = await self._federation_client.get_pdu(
@ -1218,7 +1218,7 @@ class FederationEventHandler:
if not event_infos:
return
async def prep(ev_info: _NewEventInfo):
async def prep(ev_info: _NewEventInfo) -> EventContext:
event = ev_info.event
with nested_logging_context(suffix=event.event_id):
res = await self._state_handler.compute_event_context(event)
@ -1692,7 +1692,7 @@ class FederationEventHandler:
async def _run_push_actions_and_persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
):
) -> None:
"""Run the push actions for a received event, and persist it.
Args:

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Set
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import GroupID, JsonDict, get_domain_from_id
@ -25,12 +25,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _create_rerouter(func_name):
def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]:
"""Returns an async function that looks at the group id and calls the function
on federation or the local group server if the group is local
"""
async def f(self, group_id, *args, **kwargs):
async def f(
self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any
) -> JsonDict:
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
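The newly typed factory builds one async wrapper per method name. A minimal sketch of the pattern, with illustrative `local`/`federation` attributes standing in for the handler's real clients:

```python
from typing import Any, Awaitable, Callable, Dict


def _create_rerouter(func_name: str) -> Callable[..., Awaitable[Dict[str, Any]]]:
    async def f(self: Any, group_id: str, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # Dispatch to the local implementation or to federation, depending
        # on where the group lives.
        handler = self.local if self.is_local(group_id) else self.federation
        return await getattr(handler, func_name)(group_id, *args, **kwargs)

    return f
```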

View file

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.internet import defer
@ -125,7 +125,7 @@ class InitialSyncHandler(BaseHandler):
now_token = self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
presence_stream = self.hs.get_event_sources().sources.presence
presence, _ = await presence_stream.get_new_events(
user, from_key=None, include_offline=False
)
@ -150,7 +150,7 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
async def handle_room(event: RoomsForUser):
async def handle_room(event: RoomsForUser) -> None:
d: JsonDict = {
"room_id": event.room_id,
"membership": event.membership,
@ -411,9 +411,9 @@ class InitialSyncHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()
async def get_presence():
async def get_presence() -> List[JsonDict]:
# If presence is disabled, return an empty list
if not self.hs.config.use_presence:
if not self.hs.config.server.use_presence:
return []
states = await presence_handler.get_states(
@ -428,7 +428,7 @@ class InitialSyncHandler(BaseHandler):
for s in states
]
async def get_receipts():
async def get_receipts() -> List[JsonDict]:
receipts = await self.store.get_linearized_receipts_for_room(
room_id, to_key=now_token.receipt_key
)

View file

@ -46,6 +46,7 @@ from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@ -298,7 +299,7 @@ class MessageHandler:
for user_id, profile in users_with_profile.items()
}
def maybe_schedule_expiry(self, event: EventBase):
def maybe_schedule_expiry(self, event: EventBase) -> None:
"""Schedule the expiry of an event if there's not already one scheduled,
or if the one running is for an event that will expire after the provided
timestamp.
@ -318,7 +319,7 @@ class MessageHandler:
# a task scheduled for a timestamp that's sooner than the provided one.
self._schedule_expiry_for_event(event.event_id, expiry_ts)
async def _schedule_next_expiry(self):
async def _schedule_next_expiry(self) -> None:
"""Retrieve the ID and the expiry timestamp of the next event to be expired,
and schedule an expiry task for it.
@ -331,7 +332,7 @@ class MessageHandler:
event_id, expiry_ts = res
self._schedule_expiry_for_event(event_id, expiry_ts)
def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int):
def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int) -> None:
"""Schedule an expiry task for the provided event if there's not already one
scheduled at a timestamp that's sooner than the provided one.
@ -367,7 +368,7 @@ class MessageHandler:
event_id,
)
async def _expire_event(self, event_id: str):
async def _expire_event(self, event_id: str) -> None:
"""Retrieve and expire an event that needs to be expired from the database.
If the event doesn't exist in the database, log it and delete the expiry date
@ -1229,7 +1230,10 @@ class EventCreationHandler:
self._external_cache_joined_hosts_updates[state_entry.state_group] = None
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
self,
directory_handler: DirectoryHandler,
room_alias_str: str,
expected_room_id: str,
) -> None:
"""
Ensure that the given room alias points to the expected room ID.
@ -1421,7 +1425,7 @@ class EventCreationHandler:
# structural protocol level).
is_msc2716_event = (
original_event.type == EventTypes.MSC2716_INSERTION
or original_event.type == EventTypes.MSC2716_CHUNK
or original_event.type == EventTypes.MSC2716_BATCH
or original_event.type == EventTypes.MSC2716_MARKER
)
if not room_version_obj.msc2716_historical and is_msc2716_event:
@ -1477,7 +1481,7 @@ class EventCreationHandler:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
def _notify():
def _notify() -> None:
try:
self.notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
@ -1523,7 +1527,7 @@ class EventCreationHandler:
except Exception:
logger.exception("Error bumping presence active time")
async def _send_dummy_events_to_fill_extremities(self):
async def _send_dummy_events_to_fill_extremities(self) -> None:
"""Background task to send dummy events into rooms that have a large
number of extremities
"""
@ -1600,7 +1604,7 @@ class EventCreationHandler:
)
return False
def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
def _expire_rooms_to_exclude_from_dummy_event_insertion(self) -> None:
expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
to_expire = set()
for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items():

View file

@ -14,7 +14,7 @@
# limitations under the License.
import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode, urlparse
import attr
@ -249,11 +249,11 @@ class OidcHandler:
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint"""
def __init__(self, error, error_description=None):
def __init__(self, error: str, error_description: Optional[str] = None):
self.error = error
self.error_description = error_description
def __str__(self):
def __str__(self) -> str:
if self.error_description:
return f"{self.error}: {self.error_description}"
return self.error
@ -1057,13 +1057,13 @@ class JwtClientSecret:
self._cached_secret = b""
self._cached_secret_replacement_time = 0
def __str__(self):
def __str__(self) -> str:
# if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
# encode_client_secret_basic, which calls "{}".format(secret), which ends up
# here.
return self._get_secret().decode("ascii")
def __bytes__(self):
def __bytes__(self) -> bytes:
# if client_auth_method is client_secret_post, then ClientAuth.prepare calls
# encode_client_secret_post, which ends up here.
return self._get_secret()
@ -1197,21 +1197,21 @@ class OidcSessionTokenGenerator:
)
@attr.s(frozen=True, slots=True)
@attr.s(frozen=True, slots=True, auto_attribs=True)
class OidcSessionData:
"""The attributes which are stored in a OIDC session cookie"""
# the Identity Provider being used
idp_id = attr.ib(type=str)
idp_id: str
# The `nonce` parameter passed to the OIDC provider.
nonce = attr.ib(type=str)
nonce: str
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
client_redirect_url: str
# The session ID of the ongoing UI Auth ("" if this is a login)
ui_auth_session_id = attr.ib(type=str)
ui_auth_session_id: str
class UserAttributeDict(TypedDict):
@ -1290,20 +1290,20 @@ class OidcMappingProvider(Generic[C]):
# Used to clear out "None" values in templates
def jinja_finalize(thing):
def jinja_finalize(thing: Any) -> Any:
return thing if thing is not None else ""
env = Environment(finalize=jinja_finalize)
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class JinjaOidcMappingConfig:
subject_claim = attr.ib(type=str)
localpart_template = attr.ib(type=Optional[Template])
display_name_template = attr.ib(type=Optional[Template])
email_template = attr.ib(type=Optional[Template])
extra_attributes = attr.ib(type=Dict[str, Template])
subject_claim: str
localpart_template: Optional[Template]
display_name_template: Optional[Template]
email_template: Optional[Template]
extra_attributes: Dict[str, Template]
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
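The `attr.ib(type=...)` to `auto_attribs` migration in miniature: with `auto_attribs=True`, plain annotated class attributes become attrs fields, which is both shorter and friendlier to type checkers.

```python
from typing import Dict

import attr


@attr.s(slots=True, frozen=True)
class OldConfig:
    subject_claim = attr.ib(type=str)
    extra_attributes = attr.ib(type=Dict[str, str])


# Equivalent, in the new style:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class NewConfig:
    subject_claim: str
    extra_attributes: Dict[str, str]


cfg = NewConfig(subject_claim="sub", extra_attributes={})
```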

View file

@ -15,6 +15,8 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
import attr
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
@ -24,7 +26,7 @@ from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import Requester
from synapse.types import JsonDict, Requester
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@ -36,15 +38,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@attr.s(slots=True, auto_attribs=True)
class PurgeStatus:
"""Object tracking the status of a purge request
This class contains information on the progress of a purge request, for
return by get_purge_status.
Attributes:
status (int): Tracks whether this request has completed. One of
STATUS_{ACTIVE,COMPLETE,FAILED}
"""
STATUS_ACTIVE = 0
@ -57,10 +56,10 @@ class PurgeStatus:
STATUS_FAILED: "failed",
}
def __init__(self):
self.status = PurgeStatus.STATUS_ACTIVE
# Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}.
status: int = STATUS_ACTIVE
def asdict(self):
def asdict(self) -> JsonDict:
return {"status": PurgeStatus.STATUS_TEXT[self.status]}
@ -107,7 +106,7 @@ class PaginationHandler:
async def purge_history_for_rooms_in_range(
self, min_ms: Optional[int], max_ms: Optional[int]
):
) -> None:
"""Purge outdated events from rooms within the given retention range.
If a default retention policy is defined in the server's configuration and its
@ -291,7 +290,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.discard(room_id)
# remove the purge from the list 24 hours after it completes
def clear_purge():
def clear_purge() -> None:
del self._purges_by_id[purge_id]
self.hs.get_reactor().callLater(24 * 3600, clear_purge)

View file

@ -26,18 +26,22 @@ import contextlib
import logging
from bisect import bisect
from contextlib import contextmanager
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Collection,
Dict,
FrozenSet,
Generator,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
Union,
)
@ -61,6 +65,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
from synapse.storage.databases.main import DataStore
from synapse.streams import EventSource
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import _CacheContext, cached
@ -240,7 +245,7 @@ class BasePresenceHandler(abc.ABC):
"""
@abc.abstractmethod
async def bump_presence_active_time(self, user: UserID):
async def bump_presence_active_time(self, user: UserID) -> None:
"""We've seen the user do something that indicates they're interacting
with the app.
"""
@ -274,7 +279,7 @@ class BasePresenceHandler(abc.ABC):
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
):
) -> None:
"""Process streams received over replication."""
await self._federation_queue.process_replication_rows(
stream_name, instance_name, token, rows
@ -286,7 +291,7 @@ class BasePresenceHandler(abc.ABC):
async def maybe_send_presence_to_interested_destinations(
self, states: List[UserPresenceState]
):
) -> None:
"""If this instance is a federation sender, send the states to all
destinations that are interested. Filters out any states for remote
users.
@ -309,7 +314,7 @@ class BasePresenceHandler(abc.ABC):
for destination, host_states in hosts_to_states.items():
self._federation.send_presence_to_destinations(host_states, [destination])
async def send_full_presence_to_users(self, user_ids: Collection[str]):
async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None:
"""
Adds to the list of users who should receive a full snapshot of presence
upon their next sync. Note that this only works for local users.
@ -363,7 +368,12 @@ class BasePresenceHandler(abc.ABC):
class _NullContextManager(ContextManager[None]):
"""A context manager which does nothing."""
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
pass
@ -374,7 +384,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self._presence_writer_instance = hs.config.worker.writers.presence[0]
self._presence_enabled = hs.config.use_presence
self._presence_enabled = hs.config.server.use_presence
# Route presence EDUs to the right worker
hs.get_federation_registry().register_instances_for_edu(
@ -468,7 +478,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
if self._user_to_num_current_syncs[user_id] == 1:
self.mark_as_coming_online(user_id)
def _end():
def _end() -> None:
# We check that the user_id is in user_to_num_current_syncs because
# user_to_num_current_syncs may have been cleared if we are
# shutting down.
@ -480,7 +490,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.mark_as_going_offline(user_id)
@contextlib.contextmanager
def _user_syncing():
def _user_syncing() -> Generator[None, None, None]:
try:
yield
finally:
@ -503,7 +513,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
):
) -> None:
await super().process_replication_rows(stream_name, instance_name, token, rows)
if stream_name != PresenceStream.NAME:
@ -584,7 +594,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
user_id = target_user.to_string()
# If presence is disabled, no-op
if not self.hs.config.use_presence:
if not self.hs.config.server.use_presence:
return
# Proxy request to instance that writes presence
@ -601,7 +611,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
if not self.hs.config.use_presence:
if not self.hs.config.server.use_presence:
return
# Proxy request to instance that writes presence
@ -618,7 +628,7 @@ class PresenceHandler(BasePresenceHandler):
self.server_name = hs.hostname
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()
self._presence_enabled = hs.config.use_presence
self._presence_enabled = hs.config.server.use_presence
federation_registry = hs.get_federation_registry()
@ -689,7 +699,7 @@ class PresenceHandler(BasePresenceHandler):
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
def run_timeout_handler():
def run_timeout_handler() -> Awaitable[None]:
return run_as_background_process(
"handle_presence_timeouts", self._handle_timeouts
)
@ -698,7 +708,7 @@ class PresenceHandler(BasePresenceHandler):
30, self.clock.looping_call, run_timeout_handler, 5000
)
def run_persister():
def run_persister() -> Awaitable[None]:
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
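Both `run_timeout_handler` and `run_persister` are wired up with the two-step scheduling pattern the earlier comment describes: wait out an initial grace period, then poll on a fixed interval. For the timeout handler that works out to the following call, reconstructed from the fragment above and assuming Synapse's `Clock` API (`call_later` takes seconds, `looping_call` milliseconds):

    # After a 30s grace period for reconnecting clients, check for
    # presence timeouts every 5 seconds.
    self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000)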
@ -916,7 +926,7 @@ class PresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
if not self.hs.config.use_presence:
if not self.hs.config.server.use_presence:
return
user_id = user.to_string()
@ -942,14 +952,14 @@ class PresenceHandler(BasePresenceHandler):
when users disconnect/reconnect.
Args:
user_id (str)
affect_presence (bool): If false this function will be a no-op.
user_id
affect_presence: If false this function will be a no-op.
Useful for streams that are not associated with an actual
client that is being used by a user.
"""
# Override if it should affect the user's presence, if presence is
# disabled.
if not self.hs.config.use_presence:
if not self.hs.config.server.use_presence:
affect_presence = False
if affect_presence:
@ -978,7 +988,7 @@ class PresenceHandler(BasePresenceHandler):
]
)
async def _end():
async def _end() -> None:
try:
self.user_to_num_current_syncs[user_id] -= 1
@ -994,7 +1004,7 @@ class PresenceHandler(BasePresenceHandler):
logger.exception("Error updating presence after sync")
@contextmanager
def _user_syncing():
def _user_syncing() -> Generator[None, None, None]:
try:
yield
finally:
@ -1264,7 +1274,7 @@ class PresenceHandler(BasePresenceHandler):
if self._event_processing:
return
async def _process_presence():
async def _process_presence() -> None:
assert not self._event_processing
self._event_processing = True
@ -1491,7 +1501,7 @@ def format_user_presence_state(
return content
class PresenceEventSource:
class PresenceEventSource(EventSource[int, UserPresenceState]):
def __init__(self, hs: "HomeServer"):
# We can't call get_presence_handler here because there's a cycle:
#
@ -1510,10 +1520,11 @@ class PresenceEventSource:
self,
user: UserID,
from_key: Optional[int],
limit: Optional[int] = None,
room_ids: Optional[List[str]] = None,
include_offline: bool = True,
is_guest: bool = False,
explicit_room_id: Optional[str] = None,
**kwargs,
include_offline: bool = True,
) -> Tuple[List[UserPresenceState], int]:
# The process for getting presence events is:
# 1. Get the rooms the user is in.
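This explicit signature, repeated below for the receipt, room and typing sources, mirrors the generic `EventSource` base class these classes now subclass. A sketch of that base, reconstructed from the signatures in this diff rather than quoted from `synapse/streams/__init__.py`:

    from typing import Collection, Generic, List, Optional, Tuple, TypeVar

    from synapse.types import UserID

    K = TypeVar("K")  # stream key type, e.g. int or RoomStreamToken
    R = TypeVar("R")  # result type, e.g. JsonDict or EventBase

    class EventSource(Generic[K, R]):
        async def get_new_events(
            self,
            user: UserID,
            from_key: K,
            limit: Optional[int],
            room_ids: Collection[str],
            is_guest: bool,
            explicit_room_id: Optional[str] = None,
        ) -> Tuple[List[R], K]:
            raise NotImplementedError()

Subclassing as, say, `EventSource[int, UserPresenceState]` then pins down the key and result types that each `get_new_events` implementation must use.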
@ -2074,7 +2085,7 @@ class PresenceFederationQueue:
if self._queue_presence_updates:
self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
def _clear_queue(self):
def _clear_queue(self) -> None:
"""Clear out older entries from the queue."""
clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
@ -2205,7 +2216,7 @@ class PresenceFederationQueue:
async def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: list
):
) -> None:
if stream_name != PresenceFederationStream.NAME:
return

View file

@ -214,11 +214,10 @@ class ProfileHandler(BaseHandler):
target_user.localpart, displayname_to_set
)
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
await self._update_join_states(requester, target_user)
@ -254,7 +253,7 @@ class ProfileHandler(BaseHandler):
requester: Requester,
new_avatar_url: str,
by_admin: bool = False,
):
) -> None:
"""Set a new avatar URL for a user.
Args:
@ -300,11 +299,10 @@ class ProfileHandler(BaseHandler):
target_user.localpart, avatar_url_to_set
)
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
await self._update_join_states(requester, target_user)
@ -425,7 +423,7 @@ class ProfileHandler(BaseHandler):
raise
@wrap_as_background_process("Update remote profile")
async def _update_remote_profile_cache(self):
async def _update_remote_profile_cache(self) -> None:
"""Called periodically to check profiles of remote users we haven't
checked in a while.
"""

View file

@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse.api.constants import ReadReceiptEventFields
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
if TYPE_CHECKING:
@ -162,7 +163,7 @@ class ReceiptsHandler(BaseHandler):
await self.federation_sender.send_read_receipt(receipt)
class ReceiptEventSource:
class ReceiptEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.config = hs.config
@ -216,7 +217,13 @@ class ReceiptEventSource:
return visible_events
async def get_new_events(
self, from_key: int, room_ids: List[str], user: UserID, **kwargs
self,
user: UserID,
from_key: int,
limit: Optional[int],
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
from_key = int(from_key)
to_key = self.get_current_key()

View file

@ -125,7 +125,7 @@ class RegistrationHandler(BaseHandler):
localpart: str,
guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None,
):
) -> None:
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
@ -295,11 +295,10 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(localpart)
await self.user_directory_handler.handle_local_profile_change(
user_id, profile
)
profile = await self.store.get_profileinfo(localpart)
await self.user_directory_handler.handle_local_profile_change(
user_id, profile
)
else:
# autogen a sequential user ID

View file

@ -1,6 +1,4 @@
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -22,7 +20,16 @@ import math
import random
import string
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Collection,
Dict,
List,
Optional,
Tuple,
)
from synapse.api.constants import (
EventContentFields,
@ -49,6 +56,7 @@ from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
JsonDict,
MutableStateMap,
@ -186,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
) -> str:
"""
Args:
requester: the user requesting the upgrade
@ -512,7 +520,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id: str,
new_room_id: str,
old_room_state: StateMap[str],
):
) -> None:
# check to see if we have a canonical alias.
canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
@ -902,7 +910,7 @@ class RoomCreationHandler(BaseHandler):
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
e = {"type": etype, "content": content}
e.update(event_keys)
@ -910,7 +918,7 @@ class RoomCreationHandler(BaseHandler):
return e
async def send(etype: str, content: JsonDict, **kwargs) -> int:
async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
# Allow these events to be sent even if the user is shadow-banned to
@ -1033,7 +1041,7 @@ class RoomCreationHandler(BaseHandler):
creator_id: str,
is_public: bool,
room_version: RoomVersion,
):
) -> str:
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
attempts = 0
@ -1097,7 +1105,7 @@ class RoomContextHandler:
users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
async def filter_evts(events):
async def filter_evts(events: List[EventBase]) -> List[EventBase]:
if use_admin_priviledge:
return events
return await filter_events_for_client(
@ -1175,7 +1183,7 @@ class RoomContextHandler:
return results
class RoomEventSource:
class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@ -1183,8 +1191,8 @@ class RoomEventSource:
self,
user: UserID,
from_key: RoomStreamToken,
limit: int,
room_ids: List[str],
limit: Optional[int],
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:

View file

@ -14,7 +14,7 @@
import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional, Tuple
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@ -33,7 +33,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
@ -169,7 +169,7 @@ class RoomListHandler(BaseHandler):
ignore_non_federatable=from_federation,
)
def build_room_entry(room):
def build_room_entry(room: JsonDict) -> JsonDict:
entry = {
"room_id": room["room_id"],
"name": room["name"],
@ -249,10 +249,10 @@ class RoomListHandler(BaseHandler):
self,
room_id: str,
num_joined_users: int,
cache_context,
cache_context: _CacheContext,
with_alias: bool = True,
allow_private: bool = False,
) -> Optional[dict]:
) -> Optional[JsonDict]:
"""Returns the entry for a room
Args:
@ -507,7 +507,7 @@ class RoomListNextBatch(
)
)
def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch":
return self._replace(**kwds)

View file

@ -226,7 +226,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id: Optional[str],
n_invites: int,
update: bool = True,
):
) -> None:
"""Ratelimit more than one invite sent by the given requester in the given room.
Args:
@ -250,7 +250,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
requester: Optional[Requester],
room_id: Optional[str],
invitee_user_id: str,
):
) -> None:
"""Ratelimit invites by room and by target user.
If room ID is missing then we just rate limit by target user.
@ -387,7 +387,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
return result_event.event_id, result_event.internal_metadata.stream_ordering
async def copy_room_tags_and_direct_to_room(
self, old_room_id, new_room_id, user_id
self, old_room_id: str, new_room_id: str, user_id: str
) -> None:
"""Copies the tags and direct room state from one room to another.
@ -688,7 +688,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
" (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE,
)
if old_membership == "ban" and action != "unban":
if old_membership == "ban" and action not in ["ban", "unban", "leave"]:
raise SynapseError(
403,
"Cannot %s user who was banned" % (action,),
@ -1050,7 +1050,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
event: EventBase,
context: EventContext,
ratelimit: bool = True,
):
) -> None:
"""
Change the membership status of a user in a room.

View file

@ -541,7 +541,7 @@ class RoomSummaryHandler:
origin: str,
requested_room_id: str,
suggested_only: bool,
):
) -> JsonDict:
"""
Implementation of the room hierarchy Federation API.

View file

@ -40,15 +40,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@attr.s(slots=True)
@attr.s(slots=True, auto_attribs=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
# time the session was created, in milliseconds
creation_time = attr.ib()
creation_time: int
# The user interactive authentication session ID associated with this SAML
# session (or None if this SAML session is for an initial login).
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
ui_auth_session_id: Optional[str] = None
class SamlHandler(BaseHandler):
@ -359,7 +359,7 @@ class SamlHandler(BaseHandler):
return remote_user_id
def expire_sessions(self):
def expire_sessions(self) -> None:
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
@ -391,10 +391,10 @@ MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
}
@attr.s
@attr.s(auto_attribs=True)
class SamlConfig:
mxid_source_attribute = attr.ib()
mxid_mapper = attr.ib()
mxid_source_attribute: str
mxid_mapper: Callable[[str], str]
class DefaultSamlMappingProvider:
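With `auto_attribs=True`, attrs collects plain type annotations instead of explicit `attr.ib()` calls, so the rewritten classes above behave exactly as before. A minimal standalone illustration (hypothetical class, not from this diff):

    from typing import Optional

    import attr

    @attr.s(slots=True)  # old style: one attr.ib() per field
    class SessionOld:
        creation_time = attr.ib(type=int)
        ui_auth_session_id = attr.ib(type=Optional[str], default=None)

    @attr.s(slots=True, auto_attribs=True)  # new style: plain annotations
    class SessionNew:
        creation_time: int
        ui_auth_session_id: Optional[str] = None

    # Both construct identically:
    SessionOld(creation_time=1000)
    SessionNew(creation_time=1000)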

View file

@ -17,7 +17,7 @@ import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from io import BytesIO
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional
from pkg_resources import parse_version
@ -79,7 +79,7 @@ async def _sendmail(
msg = BytesIO(msg_bytes)
d: "Deferred[object]" = Deferred()
def build_sender_factory(**kwargs) -> ESMTPSenderFactory:
def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory:
return ESMTPSenderFactory(
username,
password,

View file

@ -205,7 +205,7 @@ class SsoHandler:
self._consent_at_registration = hs.config.consent.user_consent_at_registration
def register_identity_provider(self, p: SsoIdentityProvider):
def register_identity_provider(self, p: SsoIdentityProvider) -> None:
p_id = p.idp_id
assert p_id not in self._identity_providers
self._identity_providers[p_id] = p
@ -856,7 +856,7 @@ class SsoHandler:
async def handle_terms_accepted(
self, request: Request, session_id: str, terms_version: str
):
) -> None:
"""Handle a request to the new-user 'consent' endpoint
Will serve an HTTP response to the request.
@ -959,7 +959,7 @@ class SsoHandler:
new_user=True,
)
def _expire_old_sessions(self):
def _expire_old_sessions(self) -> None:
to_expire = []
now = int(self._clock.time_msec())

View file

@ -68,7 +68,7 @@ class StatsHandler:
self._is_processing = True
async def process():
async def process() -> None:
try:
await self._unsafe_process()
finally:

View file

@ -364,7 +364,9 @@ class SyncHandler:
)
else:
async def current_sync_callback(before_token, after_token) -> SyncResult:
async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken
) -> SyncResult:
return await self.current_sync_for_user(sync_config, since_token)
result = await self.notifier.wait_for_events(
@ -441,7 +443,7 @@ class SyncHandler:
room_ids = sync_result_builder.joined_room_ids
typing_source = self.event_sources.sources["typing"]
typing_source = self.event_sources.sources.typing
typing, typing_key = await typing_source.get_new_events(
user=sync_config.user,
from_key=typing_key,
@ -463,7 +465,7 @@ class SyncHandler:
receipt_key = since_token.receipt_key if since_token else 0
receipt_source = self.event_sources.sources["receipt"]
receipt_source = self.event_sources.sources.receipt
receipts, receipt_key = await receipt_source.get_new_events(
user=sync_config.user,
from_key=receipt_key,
@ -1090,7 +1092,7 @@ class SyncHandler:
block_all_presence_data = (
since_token is None and sync_config.filter_collection.blocks_all_presence()
)
if self.hs_config.use_presence and not block_all_presence_data:
if self.hs_config.server.use_presence and not block_all_presence_data:
logger.debug("Fetching presence data")
await self._generate_sync_entry_for_presence(
sync_result_builder,
@ -1413,7 +1415,7 @@ class SyncHandler:
sync_config = sync_result_builder.sync_config
user = sync_result_builder.sync_config.user
presence_source = self.event_sources.sources["presence"]
presence_source = self.event_sources.sources.presence
since_token = sync_result_builder.since_token
presence_key = None
@ -1532,9 +1534,9 @@ class SyncHandler:
newly_joined_rooms = room_changes.newly_joined_rooms
newly_left_rooms = room_changes.newly_left_rooms
async def handle_room_entries(room_entry: "RoomSyncResultBuilder"):
async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
logger.debug("Generating room entry for %s", room_entry.room_id)
res = await self._generate_room_entry(
await self._generate_room_entry(
sync_result_builder,
ignored_users,
room_entry,
@ -1544,7 +1546,6 @@ class SyncHandler:
always_include=sync_result_builder.full_state,
)
logger.debug("Generated room entry for %s", room_entry.room_id)
return res
await concurrently_execute(handle_room_entries, room_entries, 10)
@ -1925,7 +1926,7 @@ class SyncHandler:
tags: Optional[Dict[str, Dict[str, Any]]],
account_data: Dict[str, JsonDict],
always_include: bool = False,
):
) -> None:
"""Populates the `joined` and `archived` section of `sync_result_builder`
based on the `room_builder`.

View file

@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.replication.tcp.streams import TypingStream
from synapse.streams import EventSource
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
@ -439,7 +440,7 @@ class TypingWriterHandler(FollowerTypingHandler):
raise Exception("Typing writer instance got typing info over replication")
class TypingNotificationEventSource:
class TypingNotificationEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
@ -485,7 +486,13 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
async def get_new_events(
self, from_key: int, room_ids: Iterable[str], **kwargs
self,
user: UserID,
from_key: int,
limit: Optional[int],
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)

View file

@ -70,7 +70,7 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
class TermsAuthChecker(UserInteractiveAuthChecker):
AUTH_TYPE = LoginType.TERMS
def is_enabled(self):
def is_enabled(self) -> bool:
return True
async def check_auth(self, authdict: dict, clientip: str) -> Any:

View file

@ -114,7 +114,7 @@ class UserDirectoryHandler(StateDeltasHandler):
if self._is_processing:
return
async def process():
async def process() -> None:
try:
await self._unsafe_process()
finally:
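The `finally` body is cut off at the hunk boundary; the full guard presumably resets the flag and pushes the work onto a background process, along these lines (a sketch — the method name and the use of Synapse's `run_as_background_process` helper are assumptions; the stats handler above shows the same guard shape):

    from synapse.metrics.background_process_metrics import run_as_background_process

    def notify_new_event(self) -> None:
        if self._is_processing:
            return

        async def process() -> None:
            try:
                await self._unsafe_process()
            finally:
                # Allow the next notification to start another run.
                self._is_processing = False

        self._is_processing = True
        run_as_background_process("user_directory.notify_new_event", process)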

View file

@ -321,8 +321,11 @@ class SimpleHttpClient:
self.user_agent = hs.version_string
self.clock = hs.get_clock()
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
if hs.config.server.user_agent_suffix:
self.user_agent = "%s %s" % (
self.user_agent,
hs.config.server.user_agent_suffix,
)
# We use this for our body producers to ensure that they use the correct
# reactor.

View file

@ -66,7 +66,7 @@ from synapse.http.client import (
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import JsonDict
from synapse.util import json_decoder
@ -553,20 +553,29 @@ class MatrixFederationHttpClient:
with Measure(self.clock, "outbound_request"):
# we don't want all the fancy cookie and redirect handling
# that treq.request gives: just use the raw Agent.
request_deferred = self.agent.request(
# To preserve the logging context, the timeout is treated
# in a similar way to `defer.gatherResults`:
# * Each logging context-preserving fork is wrapped in
# `run_in_background`. In this case there is only one,
# since the timeout fork is not logging-context aware.
# * The `Deferred` that joins the forks back together is
# wrapped in `make_deferred_yieldable` to restore the
# logging context regardless of the path taken.
request_deferred = run_in_background(
self.agent.request,
method_bytes,
url_bytes,
headers=Headers(headers_dict),
bodyProducer=producer,
)
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.reactor,
)
response = await request_deferred
response = await make_deferred_yieldable(request_deferred)
except DNSLookupError as e:
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e:
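Stripped of the HTTP specifics, the comments above describe a reusable shape for logcontext-safe timeouts. A minimal sketch, assuming the same helpers (`run_in_background`, `make_deferred_yieldable`, and `timeout_deferred` from `synapse.util.async_helpers`):

    from synapse.logging.context import make_deferred_yieldable, run_in_background
    from synapse.util.async_helpers import timeout_deferred

    async def call_with_timeout(reactor, make_request, timeout_sec: float):
        # Fork: run the request in its own logcontext so ours can't leak into it.
        d = run_in_background(make_request)
        # The timeout machinery is not logcontext-aware, so it wraps the
        # already-forked deferred.
        d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
        # Join: restore the caller's logcontext whichever path completes first.
        return await make_deferred_yieldable(d)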

View file

@ -21,7 +21,7 @@ from zope.interface import implementer
from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
from twisted.web.resource import IResource
from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
@ -61,7 +61,7 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""
def __init__(self, channel, *args, max_request_body_size=1024, **kw):
def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw):
Request.__init__(self, channel, *args, **kw)
self._max_request_body_size = max_request_body_size
self.site: SynapseSite = channel.site
@ -83,13 +83,13 @@ class SynapseRequest(Request):
self._is_processing = False
# the time when the asynchronous request handler completed its processing
self._processing_finished_time = None
self._processing_finished_time: Optional[float] = None
# what time we finished sending the response to the client (or the connection
# dropped)
self.finish_time = None
self.finish_time: Optional[float] = None
def __repr__(self):
def __repr__(self) -> str:
# We overwrite this so that we don't log ``access_token``
return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__,
@ -100,7 +100,7 @@ class SynapseRequest(Request):
self.site.site_tag,
)
def handleContentChunk(self, data):
def handleContentChunk(self, data: bytes) -> None:
# we should have a `content` by now.
assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self._max_request_body_size:
@ -139,7 +139,7 @@ class SynapseRequest(Request):
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester
def get_request_id(self):
def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self) -> str:
@ -205,7 +205,7 @@ class SynapseRequest(Request):
return None, None
def render(self, resrc):
def render(self, resrc: Resource) -> None:
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
@ -282,7 +282,7 @@ class SynapseRequest(Request):
if self.finish_time is not None:
self._finished_processing()
def finish(self):
def finish(self) -> None:
"""Called when all response data has been written to this Request.
Overrides twisted.web.server.Request.finish to record the finish time and do
@ -295,7 +295,7 @@ class SynapseRequest(Request):
with PreserveLoggingContext(self.logcontext):
self._finished_processing()
def connectionLost(self, reason):
def connectionLost(self, reason: Union[Failure, Exception]) -> None:
"""Called when the client connection is closed before the response is written.
Overrides twisted.web.server.Request.connectionLost to record the finish time and
@ -327,7 +327,7 @@ class SynapseRequest(Request):
if not self._is_processing:
self._finished_processing()
def _started_processing(self, servlet_name):
def _started_processing(self, servlet_name: str) -> None:
"""Record the fact that we are processing this request.
This will log the request's arrival. Once the request completes,
@ -354,9 +354,11 @@ class SynapseRequest(Request):
self.get_redacted_uri(),
)
def _finished_processing(self):
def _finished_processing(self) -> None:
"""Log the completion of this request and update the metrics"""
assert self.logcontext is not None
assert self.finish_time is not None
usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None:
@ -437,7 +439,7 @@ class XForwardedForRequest(SynapseRequest):
_forwarded_for: "Optional[_XForwardedForAddress]" = None
_forwarded_https: bool = False
def requestReceived(self, command, path, version):
def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
# this method is called by the Channel once the full request has been
# received, to dispatch the request to a resource.
# We can use it to set the IP address and protocol according to the
@ -445,7 +447,7 @@ class XForwardedForRequest(SynapseRequest):
self._process_forwarded_headers()
return super().requestReceived(command, path, version)
def _process_forwarded_headers(self):
def _process_forwarded_headers(self) -> None:
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers:
return
@ -470,7 +472,7 @@ class XForwardedForRequest(SynapseRequest):
)
self._forwarded_https = True
def isSecure(self):
def isSecure(self) -> bool:
if self._forwarded_https:
return True
return super().isSecure()
@ -545,14 +547,16 @@ class SynapseSite(Site):
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
def request_factory(channel, queued) -> Request:
def request_factory(channel, queued: bool) -> Request:
return request_class(
channel, max_request_body_size=max_request_body_size, queued=queued
channel,
max_request_body_size=max_request_body_size,
queued=queued,
)
self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
def log(self, request):
def log(self, request: SynapseRequest) -> None:
pass

View file

@ -91,7 +91,7 @@ class ModuleApi:
self._auth = hs.get_auth()
self._auth_handler = auth_handler
self._server_name = hs.hostname
self._presence_stream = hs.get_event_sources().sources["presence"]
self._presence_stream = hs.get_event_sources().sources.presence
self._state = hs.get_state_handler()
self._clock: Clock = hs.get_clock()
self._send_email_handler = hs.get_send_email_handler()

View file

@ -584,7 +584,7 @@ class Notifier:
events: List[EventBase] = []
end_token = from_token
for name, source in self.event_sources.sources.items():
for name, source in self.event_sources.sources.get_sources():
keyname = "%s_key" % name
before_id = getattr(before_token, keyname)
after_id = getattr(after_token, keyname)
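Elsewhere in this diff the `sources["typing"]`-style dict lookups become attribute accesses (`sources.typing`, `sources.receipt`, `sources.presence`), and this loop switches from `.items()` to `.get_sources()`. That implies the container is now a typed attrs class rather than a plain dict; a rough sketch of such a container (field names taken from the accesses in this diff, everything else an assumption about the real class in `synapse/streams/events.py`):

    from typing import Iterable, Tuple

    import attr

    @attr.s(frozen=True, slots=True, auto_attribs=True)
    class _EventSourcesInner:
        room: object      # RoomEventSource
        presence: object  # PresenceEventSource
        typing: object    # TypingNotificationEventSource
        receipt: object   # ReceiptEventSource

        def get_sources(self) -> Iterable[Tuple[str, object]]:
            # Preserve (name, source) iteration for callers like the loop above.
            return attr.asdict(self, recurse=False).items()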

View file

@ -370,7 +370,7 @@ class HttpPusher(Pusher):
if event.type == "m.room.member" and event.is_state():
d["notification"]["membership"] = event.content["membership"]
d["notification"]["user_is_target"] = event.state_key == self.user_id
if self.hs.config.push_include_content and event.content:
if self.hs.config.push.push_include_content and event.content:
d["notification"]["content"] = event.content
# We no longer send aliases separately, instead, we send the human

View file

@ -110,7 +110,7 @@ class Mailer:
self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage()
self.app_name = app_name
self.email_subjects: EmailSubjectConfig = hs.config.email_subjects
self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
logger.info("Created Mailer for app_name %s" % app_name)
@ -796,8 +796,8 @@ class Mailer:
Returns:
A link to open a room in the web client.
"""
if self.hs.config.email_riot_base_url:
base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
if self.hs.config.email.email_riot_base_url:
base_url = "%s/#/room" % (self.hs.config.email.email_riot_base_url)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
base_url = "https://vector.im/beta/#/room"
@ -815,9 +815,9 @@ class Mailer:
Returns:
A link to open the notification in the web client.
"""
if self.hs.config.email_riot_base_url:
if self.hs.config.email.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
self.hs.config.email.email_riot_base_url,
notif["room_id"],
notif["event_id"],
)

View file

@ -35,12 +35,12 @@ class PusherFactory:
"http": HttpPusher
}
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
logger.info("email enable notifs: %r", hs.config.email.email_enable_notifs)
if hs.config.email.email_enable_notifs:
self.mailers: Dict[str, Mailer] = {}
self._notif_template_html = hs.config.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text
self._notif_template_html = hs.config.email.email_notif_template_html
self._notif_template_text = hs.config.email.email_notif_template_text
self.pusher_types["email"] = self._create_email_pusher

View file

@ -62,7 +62,7 @@ class PusherPool:
self.clock = self.hs.get_clock()
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._pusher_shard_config = hs.config.worker.pusher_shard_config
self._instance_name = hs.get_instance_name()
self._should_start_pushers = (
self._instance_name in self._pusher_shard_config.instances

View file

@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import JsonResource
from typing import TYPE_CHECKING
from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin
from synapse.rest.client import (
account,
@ -57,6 +59,9 @@ from synapse.rest.client import (
voip,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
class ClientRestResource(JsonResource):
"""Matrix Client API REST resource.
@ -68,12 +73,12 @@ class ClientRestResource(JsonResource):
* etc
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(client_resource, hs):
def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
versions.register_servlets(hs, client_resource)
# Deprecated in r0

View file

@ -47,7 +47,7 @@ class DeviceRestServlet(RestServlet):
self.store = hs.get_datastore()
async def on_GET(
self, request: SynapseRequest, user_id, device_id: str
self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

View file

@ -125,7 +125,7 @@ class ListRoomRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
search_term = parse_string(request, "search_term")
search_term = parse_string(request, "search_term", encoding="utf-8")
if search_term == "":
raise SynapseError(
400,

View file

@ -57,7 +57,7 @@ class SendServerNoticeServlet(RestServlet):
self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
def register(self, json_resource: HttpServer):
def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__

View file

@ -419,7 +419,7 @@ class UserRegisterServlet(RestServlet):
self.nonces: Dict[str, int] = {}
self.hs = hs
def _clear_old_nonces(self):
def _clear_old_nonces(self) -> None:
"""
Clear out old nonces that are older than NONCE_TIMEOUT.
"""

View file

@ -14,6 +14,7 @@
import logging
import re
from http import HTTPStatus
from typing import TYPE_CHECKING, Awaitable, List, Tuple
from twisted.web.server import Request
@ -42,25 +43,25 @@ logger = logging.getLogger(__name__)
class RoomBatchSendEventRestServlet(RestServlet):
"""
API endpoint which can insert a chunk of events historically back in time
API endpoint which can insert a batch of events historically back in time
next to the given `prev_event`.
`chunk_id` comes from `next_chunk_id `in the response of the batch send
endpoint and is derived from the "insertion" events added to each chunk.
`batch_id` comes from `next_batch_id` in the response of the batch send
endpoint and is derived from the "insertion" events added to each batch.
It's not required for the first batch send.
`state_events_at_start` is used to define the historical state events
needed to auth the events like join events. These events will float
outside of the normal DAG as outliers and won't be visible in the chat
history which also allows us to insert multiple chunks without having a bunch
of `@mxid joined the room` noise between each chunk.
history which also allows us to insert multiple batches without having a bunch
of `@mxid joined the room` noise between each batch.
`events` is chronological chunk/list of events you want to insert.
There is a reverse-chronological constraint on chunks so once you insert
`events` is a chronological list of events you want to insert.
There is a reverse-chronological constraint on batches so once you insert
some messages, you can only insert older ones after that.
tldr; Insert chunks from your most recent history -> oldest history.
tl;dr: Insert batches from your most recent history -> oldest history.
POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event=<eventID>&chunk_id=<chunkID>
POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event_id=<eventID>&batch_id=<batchID>
{
"events": [ ... ],
"state_events_at_start": [ ... ]
@ -128,7 +129,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
self, sender: str, room_id: str, origin_server_ts: int
) -> JsonDict:
"""Creates an event dict for an "insertion" event with the proper fields
and a random chunk ID.
and a random batch ID.
Args:
sender: The event author MXID
@ -139,13 +140,13 @@ class RoomBatchSendEventRestServlet(RestServlet):
The new event dictionary to insert.
"""
next_chunk_id = random_string(8)
next_batch_id = random_string(8)
insertion_event = {
"type": EventTypes.MSC2716_INSERTION,
"sender": sender,
"room_id": room_id,
"content": {
EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id,
EventContentFields.MSC2716_HISTORICAL: True,
},
"origin_server_ts": origin_server_ts,
@ -179,7 +180,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
if not requester.app_service:
raise AuthError(
403,
HTTPStatus.FORBIDDEN,
"Only application services can use the /batchsend endpoint",
)
@ -187,24 +188,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
assert_params_in_dict(body, ["state_events_at_start", "events"])
assert request.args is not None
prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
chunk_id_from_query = parse_string(request, "chunk_id")
prev_event_ids_from_query = parse_strings_from_args(
request.args, "prev_event_id"
)
batch_id_from_query = parse_string(request, "batch_id")
if prev_events_from_query is None:
if prev_event_ids_from_query is None:
raise SynapseError(
400,
HTTPStatus.BAD_REQUEST,
"prev_event query parameter is required when inserting historical messages back in time",
errcode=Codes.MISSING_PARAM,
)
# For the event we are inserting next to (`prev_events_from_query`),
# For the event we are inserting next to (`prev_event_ids_from_query`),
# find the most recent auth events (derived from state events) that
# allowed that message to be sent. We will use that as a base
# to auth our historical messages against.
(
most_recent_prev_event_id,
_,
) = await self.store.get_max_depth_of(prev_events_from_query)
) = await self.store.get_max_depth_of(prev_event_ids_from_query)
# mapping from (type, state_key) -> state_event_id
prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_prev_event_id
@ -213,7 +216,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
prev_state_ids = list(prev_state_map.values())
auth_event_ids = prev_state_ids
state_events_at_start = []
state_event_ids_at_start = []
for state_event in body["state_events_at_start"]:
assert_params_in_dict(
state_event, ["type", "origin_server_ts", "content", "sender"]
@ -279,27 +282,38 @@ class RoomBatchSendEventRestServlet(RestServlet):
)
event_id = event.event_id
state_events_at_start.append(event_id)
state_event_ids_at_start.append(event_id)
auth_event_ids.append(event_id)
events_to_create = body["events"]
inherited_depth = await self._inherit_depth_from_prev_ids(
prev_events_from_query
prev_event_ids_from_query
)
# Figure out which chunk to connect to. If they passed in
# chunk_id_from_query let's use it. The chunk ID passed in comes
# from the chunk_id in the "insertion" event from the previous chunk.
last_event_in_chunk = events_to_create[-1]
chunk_id_to_connect_to = chunk_id_from_query
# Figure out which batch to connect to. If they passed in
# batch_id_from_query let's use it. The batch ID passed in comes
# from the batch_id in the "insertion" event from the previous batch.
last_event_in_batch = events_to_create[-1]
batch_id_to_connect_to = batch_id_from_query
base_insertion_event = None
if chunk_id_from_query:
if batch_id_from_query:
# All but the first base insertion event should point at a fake
# event, which causes the HS to ask for the state at the start of
# the chunk later.
# the batch later.
prev_event_ids = [fake_prev_event_id]
# TODO: Verify the chunk_id_from_query corresponds to an insertion event
# Verify the batch_id_from_query corresponds to an actual insertion event
# so that the batch can be connected to it.
corresponding_insertion_event_id = (
await self.store.get_insertion_event_by_batch_id(batch_id_from_query)
)
if corresponding_insertion_event_id is None:
raise SynapseError(
400,
"No insertion event corresponds to the given ?batch_id",
errcode=Codes.INVALID_PARAM,
)
pass
# Otherwise, create an insertion event to act as a starting point.
#
@ -309,12 +323,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
# an insertion event), in which case we just create a new insertion event
# that can then get pointed to by a "marker" event later.
else:
prev_event_ids = prev_events_from_query
prev_event_ids = prev_event_ids_from_query
base_insertion_event_dict = self._create_insertion_event_dict(
sender=requester.user.to_string(),
room_id=room_id,
origin_server_ts=last_event_in_chunk["origin_server_ts"],
origin_server_ts=last_event_in_batch["origin_server_ts"],
)
base_insertion_event_dict["prev_events"] = prev_event_ids.copy()
@ -333,38 +347,38 @@ class RoomBatchSendEventRestServlet(RestServlet):
depth=inherited_depth,
)
chunk_id_to_connect_to = base_insertion_event["content"][
EventContentFields.MSC2716_NEXT_CHUNK_ID
batch_id_to_connect_to = base_insertion_event["content"][
EventContentFields.MSC2716_NEXT_BATCH_ID
]
# Connect this current chunk to the insertion event from the previous chunk
chunk_event = {
"type": EventTypes.MSC2716_CHUNK,
# Connect this current batch to the insertion event from the previous batch
batch_event = {
"type": EventTypes.MSC2716_BATCH,
"sender": requester.user.to_string(),
"room_id": room_id,
"content": {
EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to,
EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to,
EventContentFields.MSC2716_HISTORICAL: True,
},
# Since the chunk event is put at the end of the chunk,
# Since the batch event is put at the end of the batch,
# where the newest-in-time event is, copy the origin_server_ts from
# the last event we're inserting
"origin_server_ts": last_event_in_chunk["origin_server_ts"],
"origin_server_ts": last_event_in_batch["origin_server_ts"],
}
# Add the chunk event to the end of the chunk (newest-in-time)
events_to_create.append(chunk_event)
# Add the batch event to the end of the batch (newest-in-time)
events_to_create.append(batch_event)
# Add an "insertion" event to the start of each chunk (next to the oldest-in-time
# event in the chunk) so the next chunk can be connected to this one.
# Add an "insertion" event to the start of each batch (next to the oldest-in-time
# event in the batch) so the next batch can be connected to this one.
insertion_event = self._create_insertion_event_dict(
sender=requester.user.to_string(),
room_id=room_id,
# Since the insertion event is put at the start of the chunk,
# Since the insertion event is put at the start of the batch,
# where the oldest-in-time event is, copy the origin_server_ts from
# the first event we're inserting
origin_server_ts=events_to_create[0]["origin_server_ts"],
)
# Prepend the insertion event to the start of the chunk (oldest-in-time)
# Prepend the insertion event to the start of the batch (oldest-in-time)
events_to_create = [insertion_event] + events_to_create
event_ids = []
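Putting the pieces together, the list handed off for persistence covers one batch, ordered oldest-in-time first, with the connective events at either end (a schematic restatement of the comments above, not code from the diff):

    # events_to_create, oldest-in-time first:
    #
    #   [ insertion_event,        # content carries MSC2716_NEXT_BATCH_ID (a fresh ID)
    #     event_1, ..., event_n,  # the historical messages from the request body
    #     batch_event ]           # content carries MSC2716_BATCH_ID = batch_id_to_connect_to
    #
    # batch_event points back at the previous batch's insertion event via
    # batch_id_to_connect_to, and this insertion_event's MSC2716_NEXT_BATCH_ID
    # is what the client passes as ?batch_id on its next call.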
@ -424,20 +438,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
context=context,
)
# Add the base_insertion_event to the bottom of the list we return
if base_insertion_event is not None:
event_ids.append(base_insertion_event.event_id)
insertion_event_id = event_ids[0]
batch_event_id = event_ids[-1]
historical_event_ids = event_ids[1:-1]
return 200, {
"state_events": state_events_at_start,
"events": event_ids,
"next_chunk_id": insertion_event["content"][
EventContentFields.MSC2716_NEXT_CHUNK_ID
response_dict = {
"state_event_ids": state_event_ids_at_start,
"event_ids": historical_event_ids,
"next_batch_id": insertion_event["content"][
EventContentFields.MSC2716_NEXT_BATCH_ID
],
"insertion_event_id": insertion_event_id,
"batch_event_id": batch_event_id,
}
if base_insertion_event is not None:
response_dict["base_insertion_event_id"] = base_insertion_event.event_id
return HTTPStatus.OK, response_dict
def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
return 501, "Not implemented"
return HTTPStatus.NOT_IMPLEMENTED, "Not implemented"
def on_PUT(
self, request: SynapseRequest, room_id: str

View file

@ -17,17 +17,22 @@ import logging
from hashlib import sha256
from http import HTTPStatus
from os import path
from typing import Dict, List
from typing import TYPE_CHECKING, Any, Dict, List
import jinja2
from jinja2 import TemplateNotFound
from twisted.web.server import Request
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_bytes_from_args, parse_string
from synapse.types import UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
# language to use for the templates. TODO: figure this out from Accept-Language
TEMPLATE_LANGUAGE = "en"
@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource):
against the user.
"""
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): homeserver
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@ -106,18 +107,14 @@ class ConsentResource(DirectServeHtmlResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8")
async def _async_render_GET(self, request):
"""
Args:
request (twisted.web.http.Request):
"""
async def _async_render_GET(self, request: Request) -> None:
version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", default="")
userhmac = None
has_consented = False
public_version = username == ""
if not public_version:
args: Dict[bytes, List[bytes]] = request.args
args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes)
@ -147,14 +144,10 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
async def _async_render_POST(self, request):
"""
Args:
request (twisted.web.http.Request):
"""
async def _async_render_POST(self, request: Request) -> None:
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
args: Dict[bytes, List[bytes]] = request.args
args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac)
@ -177,7 +170,9 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("success.html not found")
def _render_template(self, request, template_name, **template_args):
def _render_template(
self, request: Request, template_name: str, **template_args: Any
) -> None:
# get_template checks for ".." so we don't need to worry too much
# about path traversal here.
template_html = self._jinja_env.get_template(
@ -186,11 +181,11 @@ class ConsentResource(DirectServeHtmlResource):
html = template_html.render(**template_args)
respond_with_html(request, 200, html)
def _check_hash(self, userid, userhmac):
def _check_hash(self, userid: str, userhmac: bytes) -> None:
"""
Args:
userid (unicode):
userhmac (bytes):
userid:
userhmac:
Raises:
SynapseError if the hash doesn't match
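Given the `sha256` import at the top of this file and the `form_secret`-derived `self._hmac_secret` above, the body is presumably a constant-time HMAC comparison along these lines (a sketch, not the verbatim implementation):

    from hashlib import sha256
    from hmac import HMAC, compare_digest

    def _check_hash(self, userid: str, userhmac: bytes) -> None:
        want_mac = (
            HMAC(key=self._hmac_secret, msg=userid.encode("utf-8"), digestmod=sha256)
            .hexdigest()
            .encode("ascii")
        )
        if not compare_digest(want_mac, userhmac):
            raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect")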

View file

@ -13,6 +13,7 @@
# limitations under the License.
from twisted.web.resource import Resource
from twisted.web.server import Request
class HealthResource(Resource):
@ -25,6 +26,6 @@ class HealthResource(Resource):
isLeaf = 1
def render_GET(self, request):
def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", b"text/plain")
return b"OK"

View file

@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey
if TYPE_CHECKING:
from synapse.server import HomeServer
class KeyApiV2Resource(Resource):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs))

Some files were not shown because too many files have changed in this diff.