0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-12 13:01:34 +01:00

Pass the requester during event serialization. (#15174)

This allows Synapse to properly include the transaction ID in the
unsigned data of events.
This commit is contained in:
Quentin Gliech 2023-03-06 17:08:39 +01:00 committed by GitHub
parent 05e0a4089a
commit 41f127e068
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 151 additions and 75 deletions

1
changelog.d/15174.bugfix Normal file
View file

@ -0,0 +1 @@
Add the `transaction_id` in the events included in many endpoints responses.

View file

@ -38,7 +38,7 @@ from synapse.api.constants import (
)
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict
from synapse.types import JsonDict, Requester
from . import EventBase
@ -316,8 +316,9 @@ class SerializeEventConfig:
as_client_event: bool = True
# Function to convert from federation format to client format
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1
# ID of the user's auth token - used for namespacing of transaction IDs
token_id: Optional[int] = None
# The entity that requested the event. This is used to determine whether to include
# the transaction_id in the unsigned section of the event.
requester: Optional[Requester] = None
# List of event fields to include. If empty, all fields will be returned.
only_event_fields: Optional[List[str]] = None
# Some events can have stripped room state stored in the `unsigned` field.
@ -367,10 +368,23 @@ def serialize_event(
e.unsigned["redacted_because"], time_now_ms, config=config
)
if config.token_id is not None:
if config.token_id == getattr(e.internal_metadata, "token_id", None):
# If we have a txn_id saved in the internal_metadata, we should include it in the
# unsigned section of the event if it was sent by the same session as the one
# requesting the event.
# There is a special case for guests, because they only have one access token
# without associated access_token_id, so we always include the txn_id for events
# they sent.
txn_id = getattr(e.internal_metadata, "txn_id", None)
if txn_id is not None:
if txn_id is not None and config.requester is not None:
event_token_id = getattr(e.internal_metadata, "token_id", None)
if config.requester.user.to_string() == e.sender and (
(
event_token_id is not None
and config.requester.access_token_id is not None
and event_token_id == config.requester.access_token_id
)
or config.requester.is_guest
):
d["unsigned"]["transaction_id"] = txn_id
# invite_room_state and knock_room_state are a list of stripped room state events

View file

@ -23,7 +23,7 @@ from synapse.events.utils import SerializeEventConfig
from synapse.handlers.presence import format_user_presence_state
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, Requester, UserID
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@ -46,13 +46,12 @@ class EventStreamHandler:
async def get_stream(
self,
auth_user_id: str,
requester: Requester,
pagin_config: PaginationConfig,
timeout: int = 0,
as_client_event: bool = True,
affect_presence: bool = True,
room_id: Optional[str] = None,
is_guest: bool = False,
) -> JsonDict:
"""Fetches the events stream for a given user."""
@ -62,13 +61,12 @@ class EventStreamHandler:
raise SynapseError(403, "This room has been blocked on this server")
# send any outstanding server notices to the user.
await self._server_notices_sender.on_user_syncing(auth_user_id)
await self._server_notices_sender.on_user_syncing(requester.user.to_string())
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_presence_handler()
context = await presence_handler.user_syncing(
auth_user_id,
requester.user.to_string(),
affect_presence=affect_presence,
presence_state=PresenceState.ONLINE,
)
@ -82,10 +80,10 @@ class EventStreamHandler:
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
stream_result = await self.notifier.get_events_for(
auth_user,
requester.user,
pagin_config,
timeout,
is_guest=is_guest,
is_guest=requester.is_guest,
explicit_room_id=room_id,
)
events = stream_result.events
@ -102,7 +100,7 @@ class EventStreamHandler:
if event.membership != Membership.JOIN:
continue
# Send down presence.
if event.state_key == auth_user_id:
if event.state_key == requester.user.to_string():
# Send down presence for everyone in the room.
users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
@ -124,7 +122,9 @@ class EventStreamHandler:
chunks = self._event_serializer.serialize_events(
events,
time_now,
config=SerializeEventConfig(as_client_event=as_client_event),
config=SerializeEventConfig(
as_client_event=as_client_event, requester=requester
),
)
chunk = {

View file

@ -318,11 +318,9 @@ class InitialSyncHandler:
)
is_peeking = member_event_id is None
user_id = requester.user.to_string()
if membership == Membership.JOIN:
result = await self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
requester, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
# The member_event_id will always be available if membership is set
@ -330,10 +328,16 @@ class InitialSyncHandler:
assert member_event_id
result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
requester,
room_id,
pagin_config,
membership,
member_event_id,
is_peeking,
)
account_data_events = []
user_id = requester.user.to_string()
tags = await self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append(
@ -350,7 +354,7 @@ class InitialSyncHandler:
async def _room_initial_sync_parted(
self,
user_id: str,
requester: Requester,
room_id: str,
pagin_config: PaginationConfig,
membership: str,
@ -369,13 +373,17 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
self._storage_controllers, user_id, messages, is_peeking=is_peeking
self._storage_controllers,
requester.user.to_string(),
messages,
is_peeking=is_peeking,
)
start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token)
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
return {
"membership": membership,
@ -383,14 +391,18 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
self._event_serializer.serialize_events(messages, time_now)
self._event_serializer.serialize_events(
messages, time_now, config=serialize_options
)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
},
"state": (
# Don't bundle aggregations as this is a deprecated API.
self._event_serializer.serialize_events(room_state.values(), time_now)
self._event_serializer.serialize_events(
room_state.values(), time_now, config=serialize_options
)
),
"presence": [],
"receipts": [],
@ -398,7 +410,7 @@ class InitialSyncHandler:
async def _room_initial_sync_joined(
self,
user_id: str,
requester: Requester,
room_id: str,
pagin_config: PaginationConfig,
membership: str,
@ -410,9 +422,12 @@ class InitialSyncHandler:
# TODO: These concurrently
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
# Don't bundle aggregations as this is a deprecated API.
state = self._event_serializer.serialize_events(
current_state.values(), time_now
current_state.values(),
time_now,
config=serialize_options,
)
now_token = self.hs.get_event_sources().get_current_token()
@ -450,7 +465,10 @@ class InitialSyncHandler:
if not receipts:
return []
return ReceiptEventSource.filter_out_private_receipts(receipts, user_id)
return ReceiptEventSource.filter_out_private_receipts(
receipts,
requester.user.to_string(),
)
presence, receipts, (messages, token) = await make_deferred_yieldable(
gather_results(
@ -469,20 +487,23 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
self._storage_controllers, user_id, messages, is_peeking=is_peeking
self._storage_controllers,
requester.user.to_string(),
messages,
is_peeking=is_peeking,
)
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
end_token = now_token
time_now = self.clock.time_msec()
ret = {
"room_id": room_id,
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
self._event_serializer.serialize_events(messages, time_now)
self._event_serializer.serialize_events(
messages, time_now, config=serialize_options
)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),

View file

@ -50,7 +50,7 @@ from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.utils import maybe_upsert_event_field
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.logging import opentracing
@ -245,8 +245,11 @@ class MessageHandler:
)
room_state = room_state_events[membership_event_id]
now = self.clock.time_msec()
events = self._event_serializer.serialize_events(room_state.values(), now)
events = self._event_serializer.serialize_events(
room_state.values(),
self.clock.time_msec(),
config=SerializeEventConfig(requester=requester),
)
return events
async def _user_can_see_state_at_event(

View file

@ -579,7 +579,9 @@ class PaginationHandler:
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(as_client_event=as_client_event)
serialize_options = SerializeEventConfig(
as_client_event=as_client_event, requester=requester
)
chunk = {
"chunk": (

View file

@ -20,6 +20,7 @@ import attr
from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.events.utils import SerializeEventConfig
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
@ -151,16 +152,23 @@ class RelationsHandler:
)
now = self._clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
return_value: JsonDict = {
"chunk": self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
events,
now,
bundle_aggregations=aggregations,
config=serialize_options,
),
}
if include_original_event:
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
return_value["original_event"] = self._event_serializer.serialize_event(
event, now, bundle_aggregations=None
event,
now,
bundle_aggregations=None,
config=serialize_options,
)
if next_token:

View file

@ -23,7 +23,8 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
from synapse.events.utils import SerializeEventConfig
from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID
from synapse.types.state import StateFilter
from synapse.visibility import filter_events_for_client
@ -109,12 +110,12 @@ class SearchHandler:
return historical_room_ids
async def search(
self, user: UserID, content: JsonDict, batch: Optional[str] = None
self, requester: Requester, content: JsonDict, batch: Optional[str] = None
) -> JsonDict:
"""Performs a full text search for a user.
Args:
user: The user performing the search.
requester: The user performing the search.
content: Search parameters
batch: The next_batch parameter. Used for pagination.
@ -199,7 +200,7 @@ class SearchHandler:
)
return await self._search(
user,
requester,
batch_group,
batch_group_key,
batch_token,
@ -217,7 +218,7 @@ class SearchHandler:
async def _search(
self,
user: UserID,
requester: Requester,
batch_group: Optional[str],
batch_group_key: Optional[str],
batch_token: Optional[str],
@ -235,7 +236,7 @@ class SearchHandler:
"""Performs a full text search for a user.
Args:
user: The user performing the search.
requester: The user performing the search.
batch_group: Pagination information.
batch_group_key: Pagination information.
batch_token: Pagination information.
@ -269,7 +270,7 @@ class SearchHandler:
# TODO: Search through left rooms too
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
requester.user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
)
@ -303,13 +304,13 @@ class SearchHandler:
if order_by == "rank":
search_result, sender_group = await self._search_by_rank(
user, room_ids, search_term, keys, search_filter
requester.user, room_ids, search_term, keys, search_filter
)
# Unused return values for rank search.
global_next_batch = None
elif order_by == "recent":
search_result, global_next_batch = await self._search_by_recent(
user,
requester.user,
room_ids,
search_term,
keys,
@ -334,7 +335,7 @@ class SearchHandler:
assert after_limit is not None
contexts = await self._calculate_event_contexts(
user,
requester.user,
search_result.allowed_events,
before_limit,
after_limit,
@ -363,27 +364,37 @@ class SearchHandler:
# The returned events.
search_result.allowed_events,
),
user.to_string(),
requester.user.to_string(),
)
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise, the 'age' will be wrong.
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now, bundle_aggregations=aggregations
context["events_before"],
time_now,
bundle_aggregations=aggregations,
config=serialize_options,
)
context["events_after"] = self._event_serializer.serialize_events(
context["events_after"], time_now, bundle_aggregations=aggregations
context["events_after"],
time_now,
bundle_aggregations=aggregations,
config=serialize_options,
)
results = [
{
"rank": search_result.rank_map[e.event_id],
"result": self._event_serializer.serialize_event(
e, time_now, bundle_aggregations=aggregations
e,
time_now,
bundle_aggregations=aggregations,
config=serialize_options,
),
"context": contexts.get(e.event_id, {}),
}
@ -398,7 +409,9 @@ class SearchHandler:
if state_results:
rooms_cat_res["state"] = {
room_id: self._event_serializer.serialize_events(state_events, time_now)
room_id: self._event_serializer.serialize_events(
state_events, time_now, config=serialize_options
)
for room_id, state_events in state_results.items()
}

View file

@ -17,6 +17,7 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import SynapseError
from synapse.events.utils import SerializeEventConfig
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_string
from synapse.http.site import SynapseRequest
@ -43,9 +44,8 @@ class EventStreamRestServlet(RestServlet):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest
args: Dict[bytes, List[bytes]] = request.args # type: ignore
if is_guest:
if requester.is_guest:
if b"room_id" not in args:
raise SynapseError(400, "Guest users must specify room_id param")
room_id = parse_string(request, "room_id")
@ -63,13 +63,12 @@ class EventStreamRestServlet(RestServlet):
as_client_event = b"raw" not in args
chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(),
requester,
pagin_config,
timeout=timeout,
as_client_event=as_client_event,
affect_presence=(not is_guest),
affect_presence=(not requester.is_guest),
room_id=room_id,
is_guest=is_guest,
)
return 200, chunk
@ -91,9 +90,12 @@ class EventRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
event = await self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec()
if event:
result = self._event_serializer.serialize_event(event, time_now)
result = self._event_serializer.serialize_event(
event,
self.clock.time_msec(),
config=SerializeEventConfig(requester=requester),
)
return 200, result
else:
return 404, "Event not found."

View file

@ -72,6 +72,12 @@ class NotificationsServlet(RestServlet):
next_token = None
serialize_options = SerializeEventConfig(
event_format=format_event_for_client_v2_without_room_id,
requester=requester,
)
now = self.clock.time_msec()
for pa in push_actions:
returned_pa = {
"room_id": pa.room_id,
@ -81,10 +87,8 @@ class NotificationsServlet(RestServlet):
"event": (
self._event_serializer.serialize_event(
notif_events[pa.event_id],
self.clock.time_msec(),
config=SerializeEventConfig(
event_format=format_event_for_client_v2_without_room_id
),
now,
config=serialize_options,
)
),
}

View file

@ -37,7 +37,7 @@ from synapse.api.errors import (
UnredactedContentDeletedError,
)
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
from synapse.events.utils import SerializeEventConfig, format_event_for_client_v2
from synapse.http.server import HttpServer
from synapse.http.servlet import (
ResolveRoomIdMixin,
@ -814,11 +814,13 @@ class RoomEventServlet(RestServlet):
[event], requester.user.to_string()
)
time_now = self.clock.time_msec()
# per MSC2676, /rooms/{roomId}/event/{eventId}, should return the
# *original* event, rather than the edited version
event_dict = self._event_serializer.serialize_event(
event, time_now, bundle_aggregations=aggregations
event,
self.clock.time_msec(),
bundle_aggregations=aggregations,
config=SerializeEventConfig(requester=requester),
)
return 200, event_dict
@ -863,24 +865,30 @@ class RoomEventContextServlet(RestServlet):
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
serializer_options = SerializeEventConfig(requester=requester)
results = {
"events_before": self._event_serializer.serialize_events(
event_context.events_before,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
"event": self._event_serializer.serialize_event(
event_context.event,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
"events_after": self._event_serializer.serialize_events(
event_context.events_after,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
"state": self._event_serializer.serialize_events(
event_context.state, time_now
event_context.state,
time_now,
config=serializer_options,
),
"start": event_context.start,
"end": event_context.end,
@ -1192,7 +1200,7 @@ class SearchRestServlet(RestServlet):
content = parse_json_object_from_request(request)
batch = parse_string(request, "next_batch")
results = await self.search_handler.search(requester.user, content, batch)
results = await self.search_handler.search(requester, content, batch)
return 200, results

View file

@ -38,7 +38,7 @@ from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict, StreamToken
from synapse.types import JsonDict, Requester, StreamToken
from synapse.util import json_decoder
from ._base import client_patterns, set_timeline_upper_limit
@ -226,7 +226,7 @@ class SyncRestServlet(RestServlet):
# We know that the the requester has an access token since appservices
# cannot use sync.
response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection
time_now, sync_result, requester, filter_collection
)
logger.debug("Event formatting complete")
@ -237,7 +237,7 @@ class SyncRestServlet(RestServlet):
self,
time_now: int,
sync_result: SyncResult,
access_token_id: Optional[int],
requester: Requester,
filter: FilterCollection,
) -> JsonDict:
logger.debug("Formatting events in sync response")
@ -250,12 +250,12 @@ class SyncRestServlet(RestServlet):
serialize_options = SerializeEventConfig(
event_format=event_formatter,
token_id=access_token_id,
requester=requester,
only_event_fields=filter.event_fields,
)
stripped_serialize_options = SerializeEventConfig(
event_format=event_formatter,
token_id=access_token_id,
requester=requester,
include_stripped_room_state=True,
)