0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 03:53:47 +01:00

Implement 'event_format' filter param in /sync

This has been specced and part-implemented; let's implement it for /sync (but
no other endpoints yet :/).
This commit is contained in:
Richard van der Hoff 2018-09-04 15:18:25 +01:00
parent 77055dba92
commit 87c18d12ee
2 changed files with 39 additions and 13 deletions

View file

@ -251,6 +251,7 @@ class FilterCollection(object):
"include_leave", False "include_leave", False
) )
self.event_fields = filter_json.get("event_fields", []) self.event_fields = filter_json.get("event_fields", [])
self.event_format = filter_json.get("event_format", "client")
def __repr__(self): def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),) return "<FilterCollection %s>" % (json.dumps(self._filter_json),)

View file

@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import ( from synapse.events.utils import (
format_event_for_client_v2_without_room_id, format_event_for_client_v2_without_room_id,
format_event_raw,
serialize_event, serialize_event,
) )
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
@ -175,17 +176,28 @@ class SyncRestServlet(RestServlet):
@staticmethod @staticmethod
def encode_response(time_now, sync_result, access_token_id, filter): def encode_response(time_now, sync_result, access_token_id, filter):
if filter.event_format == 'client':
event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == 'federation':
event_formatter = format_event_raw
else:
raise Exception("Unknown event format %s" % (filter.event_format, ))
joined = SyncRestServlet.encode_joined( joined = SyncRestServlet.encode_joined(
sync_result.joined, time_now, access_token_id, filter.event_fields sync_result.joined, time_now, access_token_id,
filter.event_fields,
event_formatter,
) )
invited = SyncRestServlet.encode_invited( invited = SyncRestServlet.encode_invited(
sync_result.invited, time_now, access_token_id, sync_result.invited, time_now, access_token_id,
event_formatter,
) )
archived = SyncRestServlet.encode_archived( archived = SyncRestServlet.encode_archived(
sync_result.archived, time_now, access_token_id, sync_result.archived, time_now, access_token_id,
filter.event_fields, filter.event_fields,
event_formatter,
) )
return { return {
@ -228,7 +240,7 @@ class SyncRestServlet(RestServlet):
} }
@staticmethod @staticmethod
def encode_joined(rooms, time_now, token_id, event_fields): def encode_joined(rooms, time_now, token_id, event_fields, event_formatter):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
@ -241,6 +253,8 @@ class SyncRestServlet(RestServlet):
of transaction IDs of transaction IDs
event_fields(list<str>): List of event fields to include. If empty, event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned. all fields will be returned.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns: Returns:
dict[str, dict[str, object]]: the joined rooms list, in our dict[str, dict[str, object]]: the joined rooms list, in our
response format response format
@ -248,13 +262,14 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room( joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, only_fields=event_fields room, time_now, token_id, joined=True, only_fields=event_fields,
event_formatter=event_formatter,
) )
return joined return joined
@staticmethod @staticmethod
def encode_invited(rooms, time_now, token_id): def encode_invited(rooms, time_now, token_id, event_formatter):
""" """
Encode the invited rooms in a sync result Encode the invited rooms in a sync result
@ -265,6 +280,8 @@ class SyncRestServlet(RestServlet):
calculations calculations
token_id(int): ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
event_formatter (func[dict]): function to convert from federation format
to client format
Returns: Returns:
dict[str, dict[str, object]]: the invited rooms list, in our dict[str, dict[str, object]]: the invited rooms list, in our
@ -274,7 +291,7 @@ class SyncRestServlet(RestServlet):
for room in rooms: for room in rooms:
invite = serialize_event( invite = serialize_event(
room.invite, time_now, token_id=token_id, room.invite, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id, event_format=event_formatter,
is_invite=True, is_invite=True,
) )
unsigned = dict(invite.get("unsigned", {})) unsigned = dict(invite.get("unsigned", {}))
@ -288,7 +305,7 @@ class SyncRestServlet(RestServlet):
return invited return invited
@staticmethod @staticmethod
def encode_archived(rooms, time_now, token_id, event_fields): def encode_archived(rooms, time_now, token_id, event_fields, event_formatter):
""" """
Encode the archived rooms in a sync result Encode the archived rooms in a sync result
@ -301,6 +318,8 @@ class SyncRestServlet(RestServlet):
of transaction IDs of transaction IDs
event_fields(list<str>): List of event fields to include. If empty, event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned. all fields will be returned.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns: Returns:
dict[str, dict[str, object]]: The invited rooms list, in our dict[str, dict[str, object]]: The invited rooms list, in our
response format response format
@ -308,13 +327,18 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room( joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, joined=False, only_fields=event_fields room, time_now, token_id, joined=False,
only_fields=event_fields,
event_formatter=event_formatter,
) )
return joined return joined
@staticmethod @staticmethod
def encode_room(room, time_now, token_id, joined=True, only_fields=None): def encode_room(
room, time_now, token_id, joined,
only_fields, event_formatter,
):
""" """
Args: Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a room (JoinedSyncResult|ArchivedSyncResult): sync result for a
@ -326,14 +350,15 @@ class SyncRestServlet(RestServlet):
joined (bool): True if the user is joined to this room - will mean joined (bool): True if the user is joined to this room - will mean
we handle ephemeral events we handle ephemeral events
only_fields(list<str>): Optional. The list of event fields to include. only_fields(list<str>): Optional. The list of event fields to include.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns: Returns:
dict[str, object]: the room, encoded in our response format dict[str, object]: the room, encoded in our response format
""" """
def serialize(event): def serialize(event):
# TODO(mjark): Respect formatting requirements in the filter.
return serialize_event( return serialize_event(
event, time_now, token_id=token_id, event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id, event_format=event_formatter,
only_event_fields=only_fields, only_event_fields=only_fields,
) )