0
0
Fork 0
mirror of https://github.com/matrix-org/dendrite synced 2024-12-15 12:13:43 +01:00

syncapi: Rename and split out tokens (#1025)

* syncapi: Rename and split out tokens

Previously we used the badly named `PaginationToken` which was
used for both `/sync` and `/messages` requests. This quickly
became confusing because named fields like `PDUPosition` meant
different things depending on the token type. Instead, we now have
two token types: `TopologyToken` and `StreamingToken`, both of
which have fields which make more sense for their specific situations.

Updated the codebase to use one or the other. `PaginationToken` still
lives on as `syncToken`, an unexported type which both tokens rely on.
This allows us to guarantee that the specific mappings of positions
to a string remain solely under the control of the `types` package.
This enables us to move high-level conceptual things like
"decrement this topological token" to function calls e.g
`TopologicalToken.Decrement()`.

Currently broken because `/messages` seemingly used both stream and
topological tokens, though I need to confirm this.

* final tweaks/hacks

* spurious logging

* Review comments and linting
This commit is contained in:
Kegsay 2020-05-13 12:14:50 +01:00 committed by GitHub
parent 31e6a7f193
commit 5e9dce1c0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 457 additions and 439 deletions

View file

@ -90,7 +90,7 @@ func (s *OutputClientDataConsumer) onMessage(msg *sarama.ConsumerMessage) error
}).Panicf("could not save account data") }).Panicf("could not save account data")
} }
s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.PaginationToken{PDUPosition: pduPos}) s.notifier.OnNewEvent(nil, "", []string{string(msg.Key)}, types.NewStreamToken(pduPos, 0))
return nil return nil
} }

View file

@ -65,9 +65,7 @@ func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
s.notifier.OnNewEvent( s.notifier.OnNewEvent(
nil, roomID, nil, nil, roomID, nil,
types.PaginationToken{ types.NewStreamToken(0, types.StreamPosition(latestSyncPosition)),
EDUTypingPosition: types.StreamPosition(latestSyncPosition),
},
) )
}) })
@ -96,6 +94,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
} }
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.PaginationToken{EDUTypingPosition: typingPos}) s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.NewStreamToken(0, typingPos))
return nil return nil
} }

View file

@ -146,7 +146,7 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}).Panicf("roomserver output log: write event failure") }).Panicf("roomserver output log: write event failure")
return nil return nil
} }
s.notifier.OnNewEvent(&ev, "", nil, types.PaginationToken{PDUPosition: pduPos}) s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0))
return nil return nil
} }
@ -164,7 +164,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure") }).Panicf("roomserver output log: write invite failure")
return nil return nil
} }
s.notifier.OnNewEvent(&msg.Event, "", nil, types.PaginationToken{PDUPosition: pduPos}) s.notifier.OnNewEvent(&msg.Event, "", nil, types.NewStreamToken(pduPos, 0))
return nil return nil
} }

View file

@ -38,8 +38,9 @@ type messagesReq struct {
federation *gomatrixserverlib.FederationClient federation *gomatrixserverlib.FederationClient
cfg *config.Dendrite cfg *config.Dendrite
roomID string roomID string
from *types.PaginationToken from *types.TopologyToken
to *types.PaginationToken to *types.TopologyToken
fromStream *types.StreamingToken
wasToProvided bool wasToProvided bool
limit int limit int
backwardOrdering bool backwardOrdering bool
@ -66,11 +67,16 @@ func OnIncomingMessagesRequest(
// Extract parameters from the request's URL. // Extract parameters from the request's URL.
// Pagination tokens. // Pagination tokens.
from, err := types.NewPaginationTokenFromString(req.URL.Query().Get("from")) var fromStream *types.StreamingToken
from, err := types.NewTopologyTokenFromString(req.URL.Query().Get("from"))
if err != nil { if err != nil {
fs, err2 := types.NewStreamTokenFromString(req.URL.Query().Get("from"))
fromStream = &fs
if err2 != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()), JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err2.Error()),
}
} }
} }
@ -88,10 +94,10 @@ func OnIncomingMessagesRequest(
// Pagination tokens. To is optional, and its default value depends on the // Pagination tokens. To is optional, and its default value depends on the
// direction ("b" or "f"). // direction ("b" or "f").
var to *types.PaginationToken var to types.TopologyToken
wasToProvided := true wasToProvided := true
if s := req.URL.Query().Get("to"); len(s) > 0 { if s := req.URL.Query().Get("to"); len(s) > 0 {
to, err = types.NewPaginationTokenFromString(s) to, err = types.NewTopologyTokenFromString(s)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -139,8 +145,9 @@ func OnIncomingMessagesRequest(
federation: federation, federation: federation,
cfg: cfg, cfg: cfg,
roomID: roomID, roomID: roomID,
from: from, from: &from,
to: to, to: &to,
fromStream: fromStream,
wasToProvided: wasToProvided, wasToProvided: wasToProvided,
limit: limit, limit: limit,
backwardOrdering: backwardOrdering, backwardOrdering: backwardOrdering,
@ -178,12 +185,20 @@ func OnIncomingMessagesRequest(
// remote homeserver. // remote homeserver.
func (r *messagesReq) retrieveEvents() ( func (r *messagesReq) retrieveEvents() (
clientEvents []gomatrixserverlib.ClientEvent, start, clientEvents []gomatrixserverlib.ClientEvent, start,
end *types.PaginationToken, err error, end types.TopologyToken, err error,
) { ) {
// Retrieve the events from the local database. // Retrieve the events from the local database.
streamEvents, err := r.db.GetEventsInRange( var streamEvents []types.StreamEvent
if r.fromStream != nil {
toStream := r.to.StreamToken()
streamEvents, err = r.db.GetEventsInStreamingRange(
r.ctx, r.fromStream, &toStream, r.roomID, r.limit, r.backwardOrdering,
)
} else {
streamEvents, err = r.db.GetEventsInTopologicalRange(
r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering, r.ctx, r.from, r.to, r.roomID, r.limit, r.backwardOrdering,
) )
}
if err != nil { if err != nil {
err = fmt.Errorf("GetEventsInRange: %w", err) err = fmt.Errorf("GetEventsInRange: %w", err)
return return
@ -206,7 +221,7 @@ func (r *messagesReq) retrieveEvents() (
// If we didn't get any event, we don't need to proceed any further. // If we didn't get any event, we don't need to proceed any further.
if len(events) == 0 { if len(events) == 0 {
return []gomatrixserverlib.ClientEvent{}, r.from, r.to, nil return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil
} }
// Sort the events to ensure we send them in the right order. // Sort the events to ensure we send them in the right order.
@ -246,12 +261,8 @@ func (r *messagesReq) retrieveEvents() (
} }
// Generate pagination tokens to send to the client using the positions // Generate pagination tokens to send to the client using the positions
// retrieved previously. // retrieved previously.
start = types.NewPaginationTokenFromTypeAndPosition( start = types.NewTopologyToken(startPos, startStreamPos)
types.PaginationTokenTypeTopology, startPos, startStreamPos, end = types.NewTopologyToken(endPos, endStreamPos)
)
end = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, endPos, endStreamPos,
)
if r.backwardOrdering { if r.backwardOrdering {
// A stream/topological position is a cursor located between two events. // A stream/topological position is a cursor located between two events.
@ -259,14 +270,7 @@ func (r *messagesReq) retrieveEvents() (
// we consider a left to right chronological order), tokens need to refer // we consider a left to right chronological order), tokens need to refer
// to them by the event on their left, therefore we need to decrement the // to them by the event on their left, therefore we need to decrement the
// end position we send in the response if we're going backward. // end position we send in the response if we're going backward.
end.PDUPosition-- end.Decrement()
end.EDUTypingPosition += 1000
}
// The lowest token value is 1, therefore we need to manually set it to that
// value if we're below it.
if end.PDUPosition < types.StreamPosition(1) {
end.PDUPosition = types.StreamPosition(1)
} }
return clientEvents, start, end, err return clientEvents, start, end, err
@ -317,11 +321,11 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent
// The condition in the SQL query is a strict "greater than" so // The condition in the SQL query is a strict "greater than" so
// we need to check against to-1. // we need to check against to-1.
streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition) streamPos := types.StreamPosition(streamEvents[len(streamEvents)-1].StreamPosition)
isSetLargeEnough = (r.to.PDUPosition-1 == streamPos) isSetLargeEnough = (r.to.PDUPosition()-1 == streamPos)
} }
} else { } else {
streamPos := types.StreamPosition(streamEvents[0].StreamPosition) streamPos := types.StreamPosition(streamEvents[0].StreamPosition)
isSetLargeEnough = (r.from.PDUPosition-1 == streamPos) isSetLargeEnough = (r.from.PDUPosition()-1 == streamPos)
} }
} }
@ -424,18 +428,17 @@ func (r *messagesReq) backfill(roomID string, fromEventIDs []string, limit int)
func setToDefault( func setToDefault(
ctx context.Context, db storage.Database, backwardOrdering bool, ctx context.Context, db storage.Database, backwardOrdering bool,
roomID string, roomID string,
) (to *types.PaginationToken, err error) { ) (to types.TopologyToken, err error) {
if backwardOrdering { if backwardOrdering {
// go 1 earlier than the first event so we correctly fetch the earliest event // go 1 earlier than the first event so we correctly fetch the earliest event
to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) to = types.NewTopologyToken(0, 0)
} else { } else {
var pos, stream types.StreamPosition var depth, stream types.StreamPosition
pos, stream, err = db.MaxTopologicalPosition(ctx, roomID) depth, stream, err = db.MaxTopologicalPosition(ctx, roomID)
if err != nil { if err != nil {
return return
} }
to = types.NewTopologyToken(depth, stream)
to = types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, stream)
} }
return return

View file

@ -50,13 +50,13 @@ type Database interface {
// Returns an error if there was an issue with the retrieval. // Returns an error if there was an issue with the retrieval.
GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) (stateEvents []gomatrixserverlib.HeaderedEvent, err error)
// SyncPosition returns the latest positions for syncing. // SyncPosition returns the latest positions for syncing.
SyncPosition(ctx context.Context) (types.PaginationToken, error) SyncPosition(ctx context.Context) (types.StreamingToken, error)
// IncrementalSync returns all the data needed in order to create an incremental // IncrementalSync returns all the data needed in order to create an incremental
// sync response for the given user. Events returned will include any client // sync response for the given user. Events returned will include any client
// transaction IDs associated with the given device. These transaction IDs come // transaction IDs associated with the given device. These transaction IDs come
// from when the device sent the event via an API that included a transaction // from when the device sent the event via an API that included a transaction
// ID. // ID.
IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.PaginationToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
// CompleteSync returns a complete /sync API response for the given user. // CompleteSync returns a complete /sync API response for the given user.
CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error)
// GetAccountDataInRange returns all account data for a given user inserted or // GetAccountDataInRange returns all account data for a given user inserted or
@ -88,9 +88,10 @@ type Database interface {
// RemoveTypingUser removes a typing user from the typing cache. // RemoveTypingUser removes a typing user from the typing cache.
// Returns the newly calculated sync position for typing notifications. // Returns the newly calculated sync position for typing notifications.
RemoveTypingUser(userID, roomID string) types.StreamPosition RemoveTypingUser(userID, roomID string) types.StreamPosition
// GetEventsInRange retrieves all of the events on a given ordering using the // GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit.
// given extremities and limit. GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
GetEventsInRange(ctx context.Context, from, to *types.PaginationToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error) // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event. // EventPositionInTopology returns the depth and stream position of the given event.
EventPositionInTopology(ctx context.Context, eventID string) (depth types.StreamPosition, stream types.StreamPosition, err error) EventPositionInTopology(ctx context.Context, eventID string) (depth types.StreamPosition, stream types.StreamPosition, err error)
// EventsAtTopologicalPosition returns all of the events matching a given // EventsAtTopologicalPosition returns all of the events matching a given

View file

@ -228,29 +228,25 @@ func (d *SyncServerDatasource) GetStateEventsForRoom(
return return
} }
func (d *SyncServerDatasource) GetEventsInRange( func (d *SyncServerDatasource) GetEventsInTopologicalRange(
ctx context.Context, ctx context.Context,
from, to *types.PaginationToken, from, to *types.TopologyToken,
roomID string, limit int, roomID string, limit int,
backwardOrdering bool, backwardOrdering bool,
) (events []types.StreamEvent, err error) { ) (events []types.StreamEvent, err error) {
// If the pagination token's type is types.PaginationTokenTypeTopology, the
// events must be retrieved from the rooms' topology table rather than the
// table contaning the syncapi server's whole stream of events.
if from.Type == types.PaginationTokenTypeTopology {
// Determine the backward and forward limit, i.e. the upper and lower // Determine the backward and forward limit, i.e. the upper and lower
// limits to the selection in the room's topology, from the direction. // limits to the selection in the room's topology, from the direction.
var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition var backwardLimit, forwardLimit, forwardMicroLimit types.StreamPosition
if backwardOrdering { if backwardOrdering {
// Backward ordering is antichronological (latest event to oldest // Backward ordering is antichronological (latest event to oldest
// one). // one).
backwardLimit = to.PDUPosition backwardLimit = to.Depth()
forwardLimit = from.PDUPosition forwardLimit = from.Depth()
forwardMicroLimit = from.EDUTypingPosition forwardMicroLimit = from.PDUPosition()
} else { } else {
// Forward ordering is chronological (oldest event to latest one). // Forward ordering is chronological (oldest event to latest one).
backwardLimit = from.PDUPosition backwardLimit = from.Depth()
forwardLimit = to.PDUPosition forwardLimit = to.Depth()
} }
// Select the event IDs from the defined range. // Select the event IDs from the defined range.
@ -265,32 +261,35 @@ func (d *SyncServerDatasource) GetEventsInRange(
// Retrieve the events' contents using their IDs. // Retrieve the events' contents using their IDs.
events, err = d.events.selectEvents(ctx, nil, eIDs) events, err = d.events.selectEvents(ctx, nil, eIDs)
return return
} }
// If the pagination token's type is types.PaginationTokenTypeStream, the
// events must be retrieved from the table contaning the syncapi server's
// whole stream of events.
// GetEventsInStreamingRange retrieves all of the events on a given ordering using the
// given extremities and limit.
func (d *SyncServerDatasource) GetEventsInStreamingRange(
ctx context.Context,
from, to *types.StreamingToken,
roomID string, limit int,
backwardOrdering bool,
) (events []types.StreamEvent, err error) {
if backwardOrdering { if backwardOrdering {
// When using backward ordering, we want the most recent events first. // When using backward ordering, we want the most recent events first.
if events, err = d.events.selectRecentEvents( if events, err = d.events.selectRecentEvents(
ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false, ctx, nil, roomID, to.PDUPosition(), from.PDUPosition(), limit, false, false,
); err != nil { ); err != nil {
return return
} }
} else { } else {
// When using forward ordering, we want the least recent events first. // When using forward ordering, we want the least recent events first.
if events, err = d.events.selectEarlyEvents( if events, err = d.events.selectEarlyEvents(
ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit, ctx, nil, roomID, from.PDUPosition(), to.PDUPosition(), limit,
); err != nil { ); err != nil {
return return
} }
} }
return events, err
return
} }
func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.PaginationToken, error) { func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (types.StreamingToken, error) {
return d.syncPositionTx(ctx, nil) return d.syncPositionTx(ctx, nil)
} }
@ -353,7 +352,7 @@ func (d *SyncServerDatasource) syncStreamPositionTx(
func (d *SyncServerDatasource) syncPositionTx( func (d *SyncServerDatasource) syncPositionTx(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (sp types.PaginationToken, err error) { ) (sp types.StreamingToken, err error) {
maxEventID, err := d.events.selectMaxEventID(ctx, txn) maxEventID, err := d.events.selectMaxEventID(ctx, txn)
if err != nil { if err != nil {
@ -373,8 +372,7 @@ func (d *SyncServerDatasource) syncPositionTx(
if maxInviteID > maxEventID { if maxInviteID > maxEventID {
maxEventID = maxInviteID maxEventID = maxInviteID
} }
sp.PDUPosition = types.StreamPosition(maxEventID) sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.eduCache.GetLatestSyncPosition()))
sp.EDUTypingPosition = types.StreamPosition(d.eduCache.GetLatestSyncPosition())
return return
} }
@ -439,7 +437,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
// addTypingDeltaToResponse adds all typing notifications to a sync response // addTypingDeltaToResponse adds all typing notifications to a sync response
// since the specified position. // since the specified position.
func (d *SyncServerDatasource) addTypingDeltaToResponse( func (d *SyncServerDatasource) addTypingDeltaToResponse(
since types.PaginationToken, since types.StreamingToken,
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) error { ) error {
@ -448,7 +446,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
var err error var err error
for _, roomID := range joinedRoomIDs { for _, roomID := range joinedRoomIDs {
if typingUsers, updated := d.eduCache.GetTypingUsersIfUpdatedAfter( if typingUsers, updated := d.eduCache.GetTypingUsersIfUpdatedAfter(
roomID, int64(since.EDUTypingPosition), roomID, int64(since.EDUPosition()),
); updated { ); updated {
ev := gomatrixserverlib.ClientEvent{ ev := gomatrixserverlib.ClientEvent{
Type: gomatrixserverlib.MTyping, Type: gomatrixserverlib.MTyping,
@ -473,12 +471,12 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if
// the positions of that type are not equal in fromPos and toPos. // the positions of that type are not equal in fromPos and toPos.
func (d *SyncServerDatasource) addEDUDeltaToResponse( func (d *SyncServerDatasource) addEDUDeltaToResponse(
fromPos, toPos types.PaginationToken, fromPos, toPos types.StreamingToken,
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) (err error) { ) (err error) {
if fromPos.EDUTypingPosition != toPos.EDUTypingPosition { if fromPos.EDUPosition() != toPos.EDUPosition() {
err = d.addTypingDeltaToResponse( err = d.addTypingDeltaToResponse(
fromPos, joinedRoomIDs, res, fromPos, joinedRoomIDs, res,
) )
@ -490,7 +488,7 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse(
func (d *SyncServerDatasource) IncrementalSync( func (d *SyncServerDatasource) IncrementalSync(
ctx context.Context, ctx context.Context,
device authtypes.Device, device authtypes.Device,
fromPos, toPos types.PaginationToken, fromPos, toPos types.StreamingToken,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool, wantFullState bool,
) (*types.Response, error) { ) (*types.Response, error) {
@ -499,9 +497,9 @@ func (d *SyncServerDatasource) IncrementalSync(
var joinedRoomIDs []string var joinedRoomIDs []string
var err error var err error
if fromPos.PDUPosition != toPos.PDUPosition || wantFullState { if fromPos.PDUPosition() != toPos.PDUPosition() || wantFullState {
joinedRoomIDs, err = d.addPDUDeltaToResponse( joinedRoomIDs, err = d.addPDUDeltaToResponse(
ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res, ctx, device, fromPos.PDUPosition(), toPos.PDUPosition(), numRecentEventsPerRoom, wantFullState, res,
) )
} else { } else {
joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(
@ -530,7 +528,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
) ( ) (
res *types.Response, res *types.Response,
toPos types.PaginationToken, toPos types.StreamingToken,
joinedRoomIDs []string, joinedRoomIDs []string,
err error, err error,
) { ) {
@ -577,7 +575,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
var recentStreamEvents []types.StreamEvent var recentStreamEvents []types.StreamEvent
recentStreamEvents, err = d.events.selectRecentEvents( recentStreamEvents, err = d.events.selectRecentEvents(
ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition, ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition(),
numRecentEventsPerRoom, true, true, numRecentEventsPerRoom, true, true,
) )
if err != nil { if err != nil {
@ -588,27 +586,25 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
// oldest event in the room's topology. // oldest event in the room's topology.
var backwardTopologyPos, backwardStreamPos types.StreamPosition var backwardTopologyPos, backwardStreamPos types.StreamPosition
backwardTopologyPos, backwardStreamPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID()) backwardTopologyPos, backwardStreamPos, err = d.topology.selectPositionInTopology(ctx, recentStreamEvents[0].EventID())
if backwardTopologyPos-1 <= 0 { if err != nil {
backwardTopologyPos = types.StreamPosition(1) return
} else {
backwardTopologyPos--
} }
prevBatch := types.NewTopologyToken(backwardTopologyPos, backwardStreamPos)
prevBatch.Decrement()
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
// transaction IDs for complete syncs // transaction IDs for complete syncs
recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( jr.Timeline.PrevBatch = prevBatch.String()
types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos,
).String()
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = true jr.Timeline.Limited = true
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[roomID] = *jr res.Rooms.Join[roomID] = *jr
} }
if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition(), res); err != nil {
return return
} }
@ -628,7 +624,7 @@ func (d *SyncServerDatasource) CompleteSync(
// Use a zero value SyncPosition for fromPos so all EDU states are added. // Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse( err = d.addEDUDeltaToResponse(
types.PaginationToken{}, toPos, joinedRoomIDs, res, types.NewStreamToken(0, 0), toPos, joinedRoomIDs, res,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -757,14 +753,15 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) recentEvents := d.StreamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, recentStreamEvents) backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, recentStreamEvents)
prevBatch := types.NewTopologyToken(
backwardTopologyPos, backwardStreamPos,
)
switch delta.membership { switch delta.membership {
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( jr.Timeline.PrevBatch = prevBatch.String()
types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos,
).String()
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -775,9 +772,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
// TODO: recentEvents may contain events that this user is not allowed to see because they are // TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room. // no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( lr.Timeline.PrevBatch = prevBatch.String()
types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos,
).String()
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)

View file

@ -269,18 +269,14 @@ func (d *SyncServerDatasource) GetStateEventsForRoom(
return return
} }
// GetEventsInRange retrieves all of the events on a given ordering using the // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the
// given extremities and limit. // given extremities and limit.
func (d *SyncServerDatasource) GetEventsInRange( func (d *SyncServerDatasource) GetEventsInTopologicalRange(
ctx context.Context, ctx context.Context,
from, to *types.PaginationToken, from, to *types.TopologyToken,
roomID string, limit int, roomID string, limit int,
backwardOrdering bool, backwardOrdering bool,
) (events []types.StreamEvent, err error) { ) (events []types.StreamEvent, err error) {
// If the pagination token's type is types.PaginationTokenTypeTopology, the
// events must be retrieved from the rooms' topology table rather than the
// table contaning the syncapi server's whole stream of events.
if from.Type == types.PaginationTokenTypeTopology {
// TODO: ARGH CONFUSING // TODO: ARGH CONFUSING
// Determine the backward and forward limit, i.e. the upper and lower // Determine the backward and forward limit, i.e. the upper and lower
// limits to the selection in the room's topology, from the direction. // limits to the selection in the room's topology, from the direction.
@ -288,13 +284,13 @@ func (d *SyncServerDatasource) GetEventsInRange(
if backwardOrdering { if backwardOrdering {
// Backward ordering is antichronological (latest event to oldest // Backward ordering is antichronological (latest event to oldest
// one). // one).
backwardLimit = to.PDUPosition backwardLimit = to.Depth()
forwardLimit = from.PDUPosition forwardLimit = from.Depth()
forwardMicroLimit = from.EDUTypingPosition forwardMicroLimit = from.PDUPosition()
} else { } else {
// Forward ordering is chronological (oldest event to latest one). // Forward ordering is chronological (oldest event to latest one).
backwardLimit = from.PDUPosition backwardLimit = from.Depth()
forwardLimit = to.PDUPosition forwardLimit = to.Depth()
} }
// Select the event IDs from the defined range. // Select the event IDs from the defined range.
@ -309,23 +305,27 @@ func (d *SyncServerDatasource) GetEventsInRange(
// Retrieve the events' contents using their IDs. // Retrieve the events' contents using their IDs.
events, err = d.events.selectEvents(ctx, nil, eIDs) events, err = d.events.selectEvents(ctx, nil, eIDs)
return return
} }
// If the pagination token's type is types.PaginationTokenTypeStream, the
// events must be retrieved from the table contaning the syncapi server's
// whole stream of events.
// GetEventsInStreamingRange retrieves all of the events on a given ordering using the
// given extremities and limit.
func (d *SyncServerDatasource) GetEventsInStreamingRange(
ctx context.Context,
from, to *types.StreamingToken,
roomID string, limit int,
backwardOrdering bool,
) (events []types.StreamEvent, err error) {
if backwardOrdering { if backwardOrdering {
// When using backward ordering, we want the most recent events first. // When using backward ordering, we want the most recent events first.
if events, err = d.events.selectRecentEvents( if events, err = d.events.selectRecentEvents(
ctx, nil, roomID, to.PDUPosition, from.PDUPosition, limit, false, false, ctx, nil, roomID, to.PDUPosition(), from.PDUPosition(), limit, false, false,
); err != nil { ); err != nil {
return return
} }
} else { } else {
// When using forward ordering, we want the least recent events first. // When using forward ordering, we want the least recent events first.
if events, err = d.events.selectEarlyEvents( if events, err = d.events.selectEarlyEvents(
ctx, nil, roomID, from.PDUPosition, to.PDUPosition, limit, ctx, nil, roomID, from.PDUPosition(), to.PDUPosition(), limit,
); err != nil { ); err != nil {
return return
} }
@ -334,10 +334,14 @@ func (d *SyncServerDatasource) GetEventsInRange(
} }
// SyncPosition returns the latest positions for syncing. // SyncPosition returns the latest positions for syncing.
func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.PaginationToken, err error) { func (d *SyncServerDatasource) SyncPosition(ctx context.Context) (tok types.StreamingToken, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error { err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
tok, err = d.syncPositionTx(ctx, txn) pos, err := d.syncPositionTx(ctx, txn)
if err != nil {
return err return err
}
tok = *pos
return nil
}) })
return return
} }
@ -412,30 +416,31 @@ func (d *SyncServerDatasource) syncStreamPositionTx(
func (d *SyncServerDatasource) syncPositionTx( func (d *SyncServerDatasource) syncPositionTx(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (sp types.PaginationToken, err error) { ) (*types.StreamingToken, error) {
maxEventID, err := d.events.selectMaxEventID(ctx, txn) maxEventID, err := d.events.selectMaxEventID(ctx, txn)
if err != nil { if err != nil {
return sp, err return nil, err
} }
maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn) maxAccountDataID, err := d.accountData.selectMaxAccountDataID(ctx, txn)
if err != nil { if err != nil {
return sp, err return nil, err
} }
if maxAccountDataID > maxEventID { if maxAccountDataID > maxEventID {
maxEventID = maxAccountDataID maxEventID = maxAccountDataID
} }
maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn)
if err != nil { if err != nil {
return sp, err return nil, err
} }
if maxInviteID > maxEventID { if maxInviteID > maxEventID {
maxEventID = maxInviteID maxEventID = maxInviteID
} }
sp.PDUPosition = types.StreamPosition(maxEventID) sp := types.NewStreamToken(
sp.EDUTypingPosition = types.StreamPosition(d.eduCache.GetLatestSyncPosition()) types.StreamPosition(maxEventID),
sp.Type = types.PaginationTokenTypeStream types.StreamPosition(d.eduCache.GetLatestSyncPosition()),
return )
return &sp, nil
} }
// addPDUDeltaToResponse adds all PDU deltas to a sync response. // addPDUDeltaToResponse adds all PDU deltas to a sync response.
@ -499,7 +504,7 @@ func (d *SyncServerDatasource) addPDUDeltaToResponse(
// addTypingDeltaToResponse adds all typing notifications to a sync response // addTypingDeltaToResponse adds all typing notifications to a sync response
// since the specified position. // since the specified position.
func (d *SyncServerDatasource) addTypingDeltaToResponse( func (d *SyncServerDatasource) addTypingDeltaToResponse(
since types.PaginationToken, since types.StreamingToken,
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) error { ) error {
@ -508,7 +513,7 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
var err error var err error
for _, roomID := range joinedRoomIDs { for _, roomID := range joinedRoomIDs {
if typingUsers, updated := d.eduCache.GetTypingUsersIfUpdatedAfter( if typingUsers, updated := d.eduCache.GetTypingUsersIfUpdatedAfter(
roomID, int64(since.EDUTypingPosition), roomID, int64(since.EDUPosition()),
); updated { ); updated {
ev := gomatrixserverlib.ClientEvent{ ev := gomatrixserverlib.ClientEvent{
Type: gomatrixserverlib.MTyping, Type: gomatrixserverlib.MTyping,
@ -533,12 +538,12 @@ func (d *SyncServerDatasource) addTypingDeltaToResponse(
// addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if // addEDUDeltaToResponse adds updates for EDUs of each type since fromPos if
// the positions of that type are not equal in fromPos and toPos. // the positions of that type are not equal in fromPos and toPos.
func (d *SyncServerDatasource) addEDUDeltaToResponse( func (d *SyncServerDatasource) addEDUDeltaToResponse(
fromPos, toPos types.PaginationToken, fromPos, toPos types.StreamingToken,
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) (err error) { ) (err error) {
if fromPos.EDUTypingPosition != toPos.EDUTypingPosition { if fromPos.EDUPosition() != toPos.EDUPosition() {
err = d.addTypingDeltaToResponse( err = d.addTypingDeltaToResponse(
fromPos, joinedRoomIDs, res, fromPos, joinedRoomIDs, res,
) )
@ -555,18 +560,21 @@ func (d *SyncServerDatasource) addEDUDeltaToResponse(
func (d *SyncServerDatasource) IncrementalSync( func (d *SyncServerDatasource) IncrementalSync(
ctx context.Context, ctx context.Context,
device authtypes.Device, device authtypes.Device,
fromPos, toPos types.PaginationToken, fromPos, toPos types.StreamingToken,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool, wantFullState bool,
) (*types.Response, error) { ) (*types.Response, error) {
fmt.Println("from ", fromPos, "to", toPos)
nextBatchPos := fromPos.WithUpdates(toPos) nextBatchPos := fromPos.WithUpdates(toPos)
res := types.NewResponse(nextBatchPos) res := types.NewResponse(nextBatchPos)
fmt.Println("from ", fromPos, "to", toPos, "next", nextBatchPos)
var joinedRoomIDs []string var joinedRoomIDs []string
var err error var err error
if fromPos.PDUPosition != toPos.PDUPosition || wantFullState { fmt.Println("from", fromPos.PDUPosition(), "to", toPos.PDUPosition())
if fromPos.PDUPosition() != toPos.PDUPosition() || wantFullState {
joinedRoomIDs, err = d.addPDUDeltaToResponse( joinedRoomIDs, err = d.addPDUDeltaToResponse(
ctx, device, fromPos.PDUPosition, toPos.PDUPosition, numRecentEventsPerRoom, wantFullState, res, ctx, device, fromPos.PDUPosition(), toPos.PDUPosition(), numRecentEventsPerRoom, wantFullState, res,
) )
} else { } else {
joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership( joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(
@ -595,7 +603,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
) ( ) (
res *types.Response, res *types.Response,
toPos types.PaginationToken, toPos *types.StreamingToken,
joinedRoomIDs []string, joinedRoomIDs []string,
err error, err error,
) { ) {
@ -621,7 +629,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
return return
} }
res = types.NewResponse(toPos) res = types.NewResponse(*toPos)
// Extract room state and recent events for all rooms the user is joined to. // Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) joinedRoomIDs, err = d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@ -643,7 +651,7 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
var recentStreamEvents []types.StreamEvent var recentStreamEvents []types.StreamEvent
recentStreamEvents, err = d.events.selectRecentEvents( recentStreamEvents, err = d.events.selectRecentEvents(
ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition, ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition(),
numRecentEventsPerRoom, true, true, numRecentEventsPerRoom, true, true,
) )
if err != nil { if err != nil {
@ -655,28 +663,22 @@ func (d *SyncServerDatasource) getResponseWithPDUsForCompleteSync(
// oldest event in the room's topology. // oldest event in the room's topology.
var backwardTopologyPos, backwardTopologyStreamPos types.StreamPosition var backwardTopologyPos, backwardTopologyStreamPos types.StreamPosition
backwardTopologyPos, backwardTopologyStreamPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) backwardTopologyPos, backwardTopologyStreamPos, err = d.topology.selectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID())
if backwardTopologyPos-1 <= 0 { prevBatch := types.NewTopologyToken(backwardTopologyPos, backwardTopologyStreamPos)
backwardTopologyPos = types.StreamPosition(1) prevBatch.Decrement()
} else {
backwardTopologyPos--
backwardTopologyStreamPos += 1000 // this has to be bigger than the number of events we backfill per request
}
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
// transaction IDs for complete syncs // transaction IDs for complete syncs
recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents) recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( jr.Timeline.PrevBatch = prevBatch.String()
types.PaginationTokenTypeTopology, backwardTopologyPos, backwardTopologyStreamPos,
).String()
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = true jr.Timeline.Limited = true
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[roomID] = *jr res.Rooms.Join[roomID] = *jr
} }
if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition(), res); err != nil {
return return
} }
@ -697,7 +699,7 @@ func (d *SyncServerDatasource) CompleteSync(
// Use a zero value SyncPosition for fromPos so all EDU states are added. // Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse( err = d.addEDUDeltaToResponse(
types.PaginationToken{}, toPos, joinedRoomIDs, res, types.NewStreamToken(0, 0), *toPos, joinedRoomIDs, res,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -860,14 +862,14 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
recentEvents := d.StreamEventsToEvents(device, recentStreamEvents) recentEvents := d.StreamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents)
backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents) backwardTopologyPos, backwardStreamPos := d.getBackwardTopologyPos(ctx, txn, recentStreamEvents)
prevBatch := types.NewTopologyToken(
backwardTopologyPos, backwardStreamPos,
)
switch delta.membership { switch delta.membership {
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatch.String()
jr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition(
types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos,
).String()
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -878,9 +880,7 @@ func (d *SyncServerDatasource) addRoomDeltaToResponse(
// TODO: recentEvents may contain events that this user is not allowed to see because they are // TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room. // no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = types.NewPaginationTokenFromTypeAndPosition( lr.Timeline.PrevBatch = prevBatch.String()
types.PaginationTokenTypeTopology, backwardTopologyPos, backwardStreamPos,
).String()
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)

View file

@ -154,10 +154,10 @@ func TestSyncResponse(t *testing.T) {
{ {
Name: "IncrementalSync penultimate", Name: "IncrementalSync penultimate",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
from := types.NewPaginationTokenFromTypeAndPosition( // pretend we are at the penultimate event from := types.NewStreamToken( // pretend we are at the penultimate event
types.PaginationTokenTypeStream, positions[len(positions)-2], types.StreamPosition(0), positions[len(positions)-2], types.StreamPosition(0),
) )
return db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false)
}, },
WantTimeline: events[len(events)-1:], WantTimeline: events[len(events)-1:],
}, },
@ -166,11 +166,11 @@ func TestSyncResponse(t *testing.T) {
{ {
Name: "IncrementalSync limited", Name: "IncrementalSync limited",
DoSync: func() (*types.Response, error) { DoSync: func() (*types.Response, error) {
from := types.NewPaginationTokenFromTypeAndPosition( // pretend we are 10 events behind from := types.NewStreamToken( // pretend we are 10 events behind
types.PaginationTokenTypeStream, positions[len(positions)-11], types.StreamPosition(0), positions[len(positions)-11], types.StreamPosition(0),
) )
// limit is set to 5 // limit is set to 5
return db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false)
}, },
// want the last 5 events, NOT the last 10. // want the last 5 events, NOT the last 10.
WantTimeline: events[len(events)-5:], WantTimeline: events[len(events)-5:],
@ -207,7 +207,7 @@ func TestSyncResponse(t *testing.T) {
if err != nil { if err != nil {
st.Fatalf("failed to do sync: %s", err) st.Fatalf("failed to do sync: %s", err)
} }
next := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeStream, latest.PDUPosition, latest.EDUTypingPosition) next := types.NewStreamToken(latest.PDUPosition(), latest.EDUPosition())
if res.NextBatch != next.String() { if res.NextBatch != next.String() {
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
} }
@ -230,11 +230,11 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to get SyncPosition: %s", err) t.Fatalf("failed to get SyncPosition: %s", err)
} }
from := types.NewPaginationTokenFromTypeAndPosition( from := types.NewStreamToken(
types.PaginationTokenTypeStream, positions[len(positions)-2], types.StreamPosition(0), positions[len(positions)-2], types.StreamPosition(0),
) )
res, err := db.IncrementalSync(ctx, testUserDeviceA, *from, latest, 5, false) res, err := db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false)
if err != nil { if err != nil {
t.Fatalf("failed to IncrementalSync with latest token") t.Fatalf("failed to IncrementalSync with latest token")
} }
@ -249,14 +249,14 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
if prev == "" { if prev == "" {
t.Fatalf("IncrementalSync expected prev_batch token") t.Fatalf("IncrementalSync expected prev_batch token")
} }
prevBatchToken, err := types.NewPaginationTokenFromString(prev) prevBatchToken, err := types.NewTopologyTokenFromString(prev)
if err != nil { if err != nil {
t.Fatalf("failed to NewPaginationTokenFromString : %s", err) t.Fatalf("failed to NewTopologyTokenFromString : %s", err)
} }
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
// head towards the beginning of time // head towards the beginning of time
to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) to := types.NewTopologyToken(0, 0)
paginatedEvents, err := db.GetEventsInRange(ctx, prevBatchToken, to, testRoomID, 5, true) paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &prevBatchToken, &to, testRoomID, 5, true)
if err != nil { if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err) t.Fatalf("GetEventsInRange returned an error: %s", err)
} }
@ -275,10 +275,10 @@ func TestGetEventsInRangeWithStreamToken(t *testing.T) {
t.Fatalf("failed to get SyncPosition: %s", err) t.Fatalf("failed to get SyncPosition: %s", err)
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) to := types.NewStreamToken(0, 0)
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInRange(ctx, &latest, to, testRoomID, 5, true) paginatedEvents, err := db.GetEventsInStreamingRange(ctx, &latest, &to, testRoomID, 5, true)
if err != nil { if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err) t.Fatalf("GetEventsInRange returned an error: %s", err)
} }
@ -296,12 +296,12 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to get MaxTopologicalPosition: %s", err) t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
} }
from := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latest, latestStream) from := types.NewTopologyToken(latest, latestStream)
// head towards the beginning of time // head towards the beginning of time
to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) to := types.NewTopologyToken(0, 0)
// backpaginate 5 messages starting at the latest position. // backpaginate 5 messages starting at the latest position.
paginatedEvents, err := db.GetEventsInRange(ctx, from, to, testRoomID, 5, true) paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, testRoomID, 5, true)
if err != nil { if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err) t.Fatalf("GetEventsInRange returned an error: %s", err)
} }
@ -366,14 +366,14 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to get EventPositionInTopology for event: %s", err) t.Fatalf("failed to get EventPositionInTopology for event: %s", err)
} }
fromLatest := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, latestPos, latestStreamPos) fromLatest := types.NewTopologyToken(latestPos, latestStreamPos)
fromFork := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, topoPos, streamPos) fromFork := types.NewTopologyToken(topoPos, streamPos)
// head towards the beginning of time // head towards the beginning of time
to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) to := types.NewTopologyToken(0, 0)
testCases := []struct { testCases := []struct {
Name string Name string
From *types.PaginationToken From types.TopologyToken
Limit int Limit int
Wants []gomatrixserverlib.HeaderedEvent Wants []gomatrixserverlib.HeaderedEvent
}{ }{
@ -399,7 +399,7 @@ func TestGetEventsInRangeWithEventsSameDepth(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
// backpaginate messages starting at the latest position. // backpaginate messages starting at the latest position.
paginatedEvents, err := db.GetEventsInRange(ctx, tc.From, to, testRoomID, tc.Limit, true) paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &tc.From, &to, testRoomID, tc.Limit, true)
if err != nil { if err != nil {
t.Fatalf("%s GetEventsInRange returned an error: %s", tc.Name, err) t.Fatalf("%s GetEventsInRange returned an error: %s", tc.Name, err)
} }
@ -446,13 +446,13 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) {
} }
// head towards the beginning of time // head towards the beginning of time
to := types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, 0, 0) to := types.NewTopologyToken(0, 0)
// starting at `from`, backpaginate to the beginning of time, asserting as we go. // starting at `from`, backpaginate to the beginning of time, asserting as we go.
chunkSize = 3 chunkSize = 3
events = reversed(events) events = reversed(events)
for i := 0; i < len(events); i += chunkSize { for i := 0; i < len(events); i += chunkSize {
paginatedEvents, err := db.GetEventsInRange(ctx, from, to, testRoomID, chunkSize, true) paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, from, &to, testRoomID, chunkSize, true)
if err != nil { if err != nil {
t.Fatalf("GetEventsInRange returned an error: %s", err) t.Fatalf("GetEventsInRange returned an error: %s", err)
} }
@ -506,19 +506,15 @@ func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatr
} }
} }
func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *types.PaginationToken { func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *types.TopologyToken {
pos, spos, err := db.EventPositionInTopology(ctx, eventID) pos, spos, err := db.EventPositionInTopology(ctx, eventID)
if err != nil { if err != nil {
t.Fatalf("failed to get EventPositionInTopology: %s", err) t.Fatalf("failed to get EventPositionInTopology: %s", err)
} }
if pos-1 <= 0 { tok := types.NewTopologyToken(pos, spos)
pos = types.StreamPosition(1) tok.Decrement()
} else { return &tok
pos = pos - 1
spos += 1000 // this has to be bigger than the chunk limit
}
return types.NewPaginationTokenFromTypeAndPosition(types.PaginationTokenTypeTopology, pos, spos)
} }
func reversed(in []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent { func reversed(in []gomatrixserverlib.HeaderedEvent) []gomatrixserverlib.HeaderedEvent {

View file

@ -36,7 +36,7 @@ type Notifier struct {
// Protects currPos and userStreams. // Protects currPos and userStreams.
streamLock *sync.Mutex streamLock *sync.Mutex
// The latest sync position // The latest sync position
currPos types.PaginationToken currPos types.StreamingToken
// A map of user_id => UserStream which can be used to wake a given user's /sync request. // A map of user_id => UserStream which can be used to wake a given user's /sync request.
userStreams map[string]*UserStream userStreams map[string]*UserStream
// The last time we cleaned out stale entries from the userStreams map // The last time we cleaned out stale entries from the userStreams map
@ -46,7 +46,7 @@ type Notifier struct {
// NewNotifier creates a new notifier set to the given sync position. // NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and // In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier(pos types.PaginationToken) *Notifier { func NewNotifier(pos types.StreamingToken) *Notifier {
return &Notifier{ return &Notifier{
currPos: pos, currPos: pos,
roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToJoinedUsers: make(map[string]userIDSet),
@ -68,7 +68,7 @@ func NewNotifier(pos types.PaginationToken) *Notifier {
// event type it handles, leaving other fields as 0. // event type it handles, leaving other fields as 0.
func (n *Notifier) OnNewEvent( func (n *Notifier) OnNewEvent(
ev *gomatrixserverlib.HeaderedEvent, roomID string, userIDs []string, ev *gomatrixserverlib.HeaderedEvent, roomID string, userIDs []string,
posUpdate types.PaginationToken, posUpdate types.StreamingToken,
) { ) {
// update the current position then notify relevant /sync streams. // update the current position then notify relevant /sync streams.
// This needs to be done PRIOR to waking up users as they will read this value. // This needs to be done PRIOR to waking up users as they will read this value.
@ -151,7 +151,7 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
} }
// CurrentPosition returns the current sync position // CurrentPosition returns the current sync position
func (n *Notifier) CurrentPosition() types.PaginationToken { func (n *Notifier) CurrentPosition() types.StreamingToken {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
@ -173,7 +173,7 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
} }
} }
func (n *Notifier) wakeupUsers(userIDs []string, newPos types.PaginationToken) { func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) {
for _, userID := range userIDs { for _, userID := range userIDs {
stream := n.fetchUserStream(userID, false) stream := n.fetchUserStream(userID, false)
if stream != nil { if stream != nil {

View file

@ -33,11 +33,11 @@ var (
randomMessageEvent gomatrixserverlib.HeaderedEvent randomMessageEvent gomatrixserverlib.HeaderedEvent
aliceInviteBobEvent gomatrixserverlib.HeaderedEvent aliceInviteBobEvent gomatrixserverlib.HeaderedEvent
bobLeaveEvent gomatrixserverlib.HeaderedEvent bobLeaveEvent gomatrixserverlib.HeaderedEvent
syncPositionVeryOld types.PaginationToken syncPositionVeryOld = types.NewStreamToken(5, 0)
syncPositionBefore types.PaginationToken syncPositionBefore = types.NewStreamToken(11, 0)
syncPositionAfter types.PaginationToken syncPositionAfter = types.NewStreamToken(12, 0)
syncPositionNewEDU types.PaginationToken syncPositionNewEDU = types.NewStreamToken(syncPositionAfter.PDUPosition(), 1)
syncPositionAfter2 types.PaginationToken syncPositionAfter2 = types.NewStreamToken(13, 0)
) )
var ( var (
@ -47,26 +47,6 @@ var (
) )
func init() { func init() {
baseSyncPos := types.PaginationToken{
PDUPosition: 0,
EDUTypingPosition: 0,
}
syncPositionVeryOld = baseSyncPos
syncPositionVeryOld.PDUPosition = 5
syncPositionBefore = baseSyncPos
syncPositionBefore.PDUPosition = 11
syncPositionAfter = baseSyncPos
syncPositionAfter.PDUPosition = 12
syncPositionNewEDU = syncPositionAfter
syncPositionNewEDU.EDUTypingPosition = 1
syncPositionAfter2 = baseSyncPos
syncPositionAfter2.PDUPosition = 13
var err error var err error
err = json.Unmarshal([]byte(`{ err = json.Unmarshal([]byte(`{
"_room_version": "1", "_room_version": "1",
@ -118,6 +98,12 @@ func init() {
} }
} }
func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
if got.String() != want.String() {
t.Fatalf("mustEqualPositions got %s want %s", got.String(), want.String())
}
}
// Test that the current position is returned if a request is already behind. // Test that the current position is returned if a request is already behind.
func TestImmediateNotification(t *testing.T) { func TestImmediateNotification(t *testing.T) {
n := NewNotifier(syncPositionBefore) n := NewNotifier(syncPositionBefore)
@ -125,9 +111,7 @@ func TestImmediateNotification(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("TestImmediateNotification error: %s", err) t.Fatalf("TestImmediateNotification error: %s", err)
} }
if pos != syncPositionBefore { mustEqualPositions(t, pos, syncPositionBefore)
t.Fatalf("TestImmediateNotification want %v, got %v", syncPositionBefore, pos)
}
} }
// Test that new events to a joined room unblocks the request. // Test that new events to a joined room unblocks the request.
@ -144,9 +128,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("TestNewEventAndJoinedToRoom error: %w", err) t.Errorf("TestNewEventAndJoinedToRoom error: %w", err)
} }
if pos != syncPositionAfter { mustEqualPositions(t, pos, syncPositionAfter)
t.Errorf("TestNewEventAndJoinedToRoom want %v, got %v", syncPositionAfter, pos)
}
wg.Done() wg.Done()
}() }()
@ -172,9 +154,7 @@ func TestNewInviteEventForUser(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("TestNewInviteEventForUser error: %w", err) t.Errorf("TestNewInviteEventForUser error: %w", err)
} }
if pos != syncPositionAfter { mustEqualPositions(t, pos, syncPositionAfter)
t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionAfter, pos)
}
wg.Done() wg.Done()
}() }()
@ -200,9 +180,7 @@ func TestEDUWakeup(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("TestNewInviteEventForUser error: %w", err) t.Errorf("TestNewInviteEventForUser error: %w", err)
} }
if pos != syncPositionNewEDU { mustEqualPositions(t, pos, syncPositionNewEDU)
t.Errorf("TestNewInviteEventForUser want %v, got %v", syncPositionNewEDU, pos)
}
wg.Done() wg.Done()
}() }()
@ -228,9 +206,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("TestMultipleRequestWakeup error: %w", err) t.Errorf("TestMultipleRequestWakeup error: %w", err)
} }
if pos != syncPositionAfter { mustEqualPositions(t, pos, syncPositionAfter)
t.Errorf("TestMultipleRequestWakeup want %v, got %v", syncPositionAfter, pos)
}
wg.Done() wg.Done()
} }
go poll() go poll()
@ -268,9 +244,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err)
} }
if pos != syncPositionAfter { mustEqualPositions(t, pos, syncPositionAfter)
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter, pos)
}
leaveWG.Done() leaveWG.Done()
}() }()
bobStream := lockedFetchUserStream(n, bob) bobStream := lockedFetchUserStream(n, bob)
@ -287,9 +261,7 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err) t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom error: %w", err)
} }
if pos != syncPositionAfter2 { mustEqualPositions(t, pos, syncPositionAfter2)
t.Errorf("TestNewEventAndWasPreviouslyJoinedToRoom want %v, got %v", syncPositionAfter2, pos)
}
aliceWG.Done() aliceWG.Done()
}() }()
@ -312,13 +284,13 @@ func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
} }
func waitForEvents(n *Notifier, req syncRequest) (types.PaginationToken, error) { func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) {
listener := n.GetListener(req) listener := n.GetListener(req)
defer listener.Close() defer listener.Close()
select { select {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
return types.PaginationToken{}, fmt.Errorf( return types.StreamingToken{}, fmt.Errorf(
"waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since, "waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since,
) )
case <-listener.GetNotifyChannel(*req.since): case <-listener.GetNotifyChannel(*req.since):
@ -344,7 +316,7 @@ func lockedFetchUserStream(n *Notifier, userID string) *UserStream {
return n.fetchUserStream(userID, true) return n.fetchUserStream(userID, true)
} }
func newTestSyncRequest(userID string, since types.PaginationToken) syncRequest { func newTestSyncRequest(userID string, since types.StreamingToken) syncRequest {
return syncRequest{ return syncRequest{
device: authtypes.Device{UserID: userID}, device: authtypes.Device{UserID: userID},
timeout: 1 * time.Minute, timeout: 1 * time.Minute,

View file

@ -36,7 +36,7 @@ type syncRequest struct {
device authtypes.Device device authtypes.Device
limit int limit int
timeout time.Duration timeout time.Duration
since *types.PaginationToken // nil means that no since token was supplied since *types.StreamingToken // nil means that no since token was supplied
wantFullState bool wantFullState bool
log *log.Entry log *log.Entry
} }
@ -45,10 +45,15 @@ func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, e
timeout := getTimeout(req.URL.Query().Get("timeout")) timeout := getTimeout(req.URL.Query().Get("timeout"))
fullState := req.URL.Query().Get("full_state") fullState := req.URL.Query().Get("full_state")
wantFullState := fullState != "" && fullState != "false" wantFullState := fullState != "" && fullState != "false"
since, err := getPaginationToken(req.URL.Query().Get("since")) var since *types.StreamingToken
sinceStr := req.URL.Query().Get("since")
if sinceStr != "" {
tok, err := types.NewStreamTokenFromString(sinceStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
since = &tok
}
// TODO: Additional query params: set_presence, filter // TODO: Additional query params: set_presence, filter
return &syncRequest{ return &syncRequest{
ctx: req.Context(), ctx: req.Context(),
@ -71,16 +76,3 @@ func getTimeout(timeoutMS string) time.Duration {
} }
return time.Duration(i) * time.Millisecond return time.Duration(i) * time.Millisecond
} }
// getSyncStreamPosition tries to parse a 'since' token taken from the API to a
// types.PaginationToken. If the string is empty then (nil, nil) is returned.
// There are two forms of tokens: The full length form containing all PDU and EDU
// positions separated by "_", and the short form containing only the PDU
// position. Short form can be used for, e.g., `prev_batch` tokens.
func getPaginationToken(since string) (*types.PaginationToken, error) {
if since == "" {
return nil, nil
}
return types.NewPaginationTokenFromString(since)
}

View file

@ -132,7 +132,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
} }
} }
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.PaginationToken) (res *types.Response, err error) { func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
// TODO: handle ignored users // TODO: handle ignored users
if req.since == nil { if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit)
@ -145,7 +145,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Pagin
} }
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition, &accountDataFilter) res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter)
return return
} }
@ -187,7 +187,7 @@ func (rp *RequestPool) appendAccountData(
// Sync is not initial, get all account data since the latest sync // Sync is not initial, get all account data since the latest sync
dataTypes, err := rp.db.GetAccountDataInRange( dataTypes, err := rp.db.GetAccountDataInRange(
req.ctx, userID, req.ctx, userID,
types.StreamPosition(req.since.PDUPosition), types.StreamPosition(currentPos), types.StreamPosition(req.since.PDUPosition()), types.StreamPosition(currentPos),
accountDataFilter, accountDataFilter,
) )
if err != nil { if err != nil {

View file

@ -34,7 +34,7 @@ type UserStream struct {
// Closed when there is an update. // Closed when there is an update.
signalChannel chan struct{} signalChannel chan struct{}
// The last sync position that there may have been an update for the user // The last sync position that there may have been an update for the user
pos types.PaginationToken pos types.StreamingToken
// The last time when we had some listeners waiting // The last time when we had some listeners waiting
timeOfLastChannel time.Time timeOfLastChannel time.Time
// The number of listeners waiting // The number of listeners waiting
@ -50,7 +50,7 @@ type UserStreamListener struct {
} }
// NewUserStream creates a new user stream // NewUserStream creates a new user stream
func NewUserStream(userID string, currPos types.PaginationToken) *UserStream { func NewUserStream(userID string, currPos types.StreamingToken) *UserStream {
return &UserStream{ return &UserStream{
UserID: userID, UserID: userID,
timeOfLastChannel: time.Now(), timeOfLastChannel: time.Now(),
@ -83,7 +83,7 @@ func (s *UserStream) GetListener(ctx context.Context) UserStreamListener {
} }
// Broadcast a new sync position for this user. // Broadcast a new sync position for this user.
func (s *UserStream) Broadcast(pos types.PaginationToken) { func (s *UserStream) Broadcast(pos types.StreamingToken) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
@ -116,9 +116,9 @@ func (s *UserStream) TimeOfLastNonEmpty() time.Time {
return s.timeOfLastChannel return s.timeOfLastChannel
} }
// GetStreamPosition returns last sync position which the UserStream was // GetSyncPosition returns last sync position which the UserStream was
// notified about // notified about
func (s *UserStreamListener) GetSyncPosition() types.PaginationToken { func (s *UserStreamListener) GetSyncPosition() types.StreamingToken {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()
@ -130,7 +130,7 @@ func (s *UserStreamListener) GetSyncPosition() types.PaginationToken {
// sincePos specifies from which point we want to be notified about. If there // sincePos specifies from which point we want to be notified about. If there
// has already been an update after sincePos we'll return a closed channel // has already been an update after sincePos we'll return a closed channel
// immediately. // immediately.
func (s *UserStreamListener) GetNotifyChannel(sincePos types.PaginationToken) <-chan struct{} { func (s *UserStreamListener) GetNotifyChannel(sincePos types.StreamingToken) <-chan struct{} {
s.userStream.lock.Lock() s.userStream.lock.Lock()
defer s.userStream.lock.Unlock() defer s.userStream.lock.Unlock()

View file

@ -27,19 +27,19 @@ import (
) )
var ( var (
// ErrInvalidPaginationTokenType is returned when an attempt at creating a // ErrInvalidSyncTokenType is returned when an attempt at creating a
// new instance of PaginationToken with an invalid type (i.e. neither "s" // new instance of SyncToken with an invalid type (i.e. neither "s"
// nor "t"). // nor "t").
ErrInvalidPaginationTokenType = fmt.Errorf("Pagination token has an unknown prefix (should be either s or t)") ErrInvalidSyncTokenType = fmt.Errorf("Sync token has an unknown prefix (should be either s or t)")
// ErrInvalidPaginationTokenLen is returned when the pagination token is an // ErrInvalidSyncTokenLen is returned when the pagination token is an
// invalid length // invalid length
ErrInvalidPaginationTokenLen = fmt.Errorf("Pagination token has an invalid length") ErrInvalidSyncTokenLen = fmt.Errorf("Sync token has an invalid length")
) )
// StreamPosition represents the offset in the sync stream a client is at. // StreamPosition represents the offset in the sync stream a client is at.
type StreamPosition int64 type StreamPosition int64
// Same as gomatrixserverlib.Event but also has the PDU stream position for this event. // StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamEvent struct { type StreamEvent struct {
gomatrixserverlib.HeaderedEvent gomatrixserverlib.HeaderedEvent
StreamPosition StreamPosition StreamPosition StreamPosition
@ -47,118 +47,201 @@ type StreamEvent struct {
ExcludeFromSync bool ExcludeFromSync bool
} }
// PaginationTokenType represents the type of a pagination token. // SyncTokenType represents the type of a sync token.
// It can be either "s" (representing a position in the whole stream of events) // It can be either "s" (representing a position in the whole stream of events)
// or "t" (representing a position in a room's topology/depth). // or "t" (representing a position in a room's topology/depth).
type PaginationTokenType string type SyncTokenType string
const ( const (
// PaginationTokenTypeStream represents a position in the server's whole // SyncTokenTypeStream represents a position in the server's whole
// stream of events // stream of events
PaginationTokenTypeStream PaginationTokenType = "s" SyncTokenTypeStream SyncTokenType = "s"
// PaginationTokenTypeTopology represents a position in a room's topology. // SyncTokenTypeTopology represents a position in a room's topology.
PaginationTokenTypeTopology PaginationTokenType = "t" SyncTokenTypeTopology SyncTokenType = "t"
) )
// PaginationToken represents a pagination token, used for interactions with type StreamingToken struct {
// /sync or /messages, for example. syncToken
type PaginationToken struct {
//Position StreamPosition
Type PaginationTokenType
// For /sync, this is the PDU position. For /messages, this is the topological position (depth).
// TODO: Given how different the positions are depending on the token type, they should probably be renamed
// or use different structs altogether.
PDUPosition StreamPosition
// For /sync, this is the EDU position. For /messages, this is the stream (PDU) position.
// TODO: Given how different the positions are depending on the token type, they should probably be renamed
// or use different structs altogether.
EDUTypingPosition StreamPosition
} }
// NewPaginationTokenFromString takes a string of the form "xyyyy..." where "x" func (t *StreamingToken) PDUPosition() StreamPosition {
// represents the type of a pagination token and "yyyy..." the token itself, and return t.Positions[0]
// parses it in order to create a new instance of PaginationToken. Returns an }
// error if the token couldn't be parsed into an int64, or if the token type func (t *StreamingToken) EDUPosition() StreamPosition {
// isn't a known type (returns ErrInvalidPaginationTokenType in the latter return t.Positions[1]
// case).
func NewPaginationTokenFromString(s string) (token *PaginationToken, err error) {
if len(s) == 0 {
return nil, ErrInvalidPaginationTokenLen
}
token = new(PaginationToken)
var positions []string
switch t := PaginationTokenType(s[:1]); t {
case PaginationTokenTypeStream, PaginationTokenTypeTopology:
token.Type = t
positions = strings.Split(s[1:], "_")
default:
token.Type = PaginationTokenTypeStream
positions = strings.Split(s, "_")
}
// Try to get the PDU position.
if len(positions) >= 1 {
if pduPos, err := strconv.ParseInt(positions[0], 10, 64); err != nil {
return nil, err
} else if pduPos < 0 {
return nil, errors.New("negative PDU position not allowed")
} else {
token.PDUPosition = StreamPosition(pduPos)
}
}
// Try to get the typing position.
if len(positions) >= 2 {
if typPos, err := strconv.ParseInt(positions[1], 10, 64); err != nil {
return nil, err
} else if typPos < 0 {
return nil, errors.New("negative EDU typing position not allowed")
} else {
token.EDUTypingPosition = StreamPosition(typPos)
}
}
return
} }
// NewPaginationTokenFromTypeAndPosition takes a PaginationTokenType and a // IsAfter returns true if ANY position in this token is greater than `other`.
// StreamPosition and returns an instance of PaginationToken. func (t *StreamingToken) IsAfter(other StreamingToken) bool {
func NewPaginationTokenFromTypeAndPosition( for i := range other.Positions {
t PaginationTokenType, pdupos StreamPosition, typpos StreamPosition, if t.Positions[i] > other.Positions[i] {
) (p *PaginationToken) { return true
return &PaginationToken{
Type: t,
PDUPosition: pdupos,
EDUTypingPosition: typpos,
} }
}
return false
} }
// String translates a PaginationToken to a string of the "xyyyy..." (see // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
// NewPaginationToken to know what it represents). // If the latter StreamingToken contains a field that is not 0, it is considered an update,
func (p *PaginationToken) String() string { // and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called.
return fmt.Sprintf("%s%d_%d", p.Type, p.PDUPosition, p.EDUTypingPosition) func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) {
} ret.Type = t.Type
ret.Positions = make([]StreamPosition, len(t.Positions))
// WithUpdates returns a copy of the PaginationToken with updates applied from another PaginationToken. for i := range t.Positions {
// If the latter PaginationToken contains a field that is not 0, it is considered an update, ret.Positions[i] = t.Positions[i]
// and its value will replace the corresponding value in the PaginationToken on which WithUpdates is called. if other.Positions[i] == 0 {
func (pt *PaginationToken) WithUpdates(other PaginationToken) PaginationToken { continue
ret := *pt
if other.PDUPosition != 0 {
ret.PDUPosition = other.PDUPosition
} }
if other.EDUTypingPosition != 0 { ret.Positions[i] = other.Positions[i]
ret.EDUTypingPosition = other.EDUTypingPosition
} }
return ret return ret
} }
// IsAfter returns whether one PaginationToken refers to states newer than another PaginationToken. type TopologyToken struct {
func (sp *PaginationToken) IsAfter(other PaginationToken) bool { syncToken
return sp.PDUPosition > other.PDUPosition || }
sp.EDUTypingPosition > other.EDUTypingPosition
func (t *TopologyToken) Depth() StreamPosition {
return t.Positions[0]
}
func (t *TopologyToken) PDUPosition() StreamPosition {
return t.Positions[1]
}
func (t *TopologyToken) StreamToken() StreamingToken {
return NewStreamToken(t.PDUPosition(), 0)
}
func (t *TopologyToken) String() string {
return t.syncToken.String()
}
// Decrement the topology token to one event earlier.
func (t *TopologyToken) Decrement() {
depth := t.Positions[0]
pduPos := t.Positions[1]
if depth-1 <= 0 {
depth = 1
} else {
depth--
pduPos += 1000
}
// The lowest token value is 1, therefore we need to manually set it to that
// value if we're below it.
if depth < 1 {
depth = 1
}
t.Positions = []StreamPosition{
depth, pduPos,
}
}
// NewSyncTokenFromString takes a string of the form "xyyyy..." where "x"
// represents the type of a pagination token and "yyyy..." the token itself, and
// parses it in order to create a new instance of SyncToken. Returns an
// error if the token couldn't be parsed into an int64, or if the token type
// isn't a known type (returns ErrInvalidSyncTokenType in the latter
// case).
func newSyncTokenFromString(s string) (token *syncToken, err error) {
if len(s) == 0 {
return nil, ErrInvalidSyncTokenLen
}
token = new(syncToken)
var positions []string
switch t := SyncTokenType(s[:1]); t {
case SyncTokenTypeStream, SyncTokenTypeTopology:
token.Type = t
positions = strings.Split(s[1:], "_")
default:
return nil, ErrInvalidSyncTokenType
}
for _, pos := range positions {
if posInt, err := strconv.ParseInt(pos, 10, 64); err != nil {
return nil, err
} else if posInt < 0 {
return nil, errors.New("negative position not allowed")
} else {
token.Positions = append(token.Positions, StreamPosition(posInt))
}
}
return
}
// NewTopologyToken creates a new sync token for /messages
func NewTopologyToken(depth, streamPos StreamPosition) TopologyToken {
if depth < 0 {
depth = 1
}
return TopologyToken{
syncToken: syncToken{
Type: SyncTokenTypeTopology,
Positions: []StreamPosition{depth, streamPos},
},
}
}
func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
t, err := newSyncTokenFromString(tok)
if err != nil {
return
}
if t.Type != SyncTokenTypeTopology {
err = fmt.Errorf("token %s is not a topology token", tok)
return
}
if len(t.Positions) != 2 {
err = fmt.Errorf("token %s wrong number of values, got %d want 2", tok, len(t.Positions))
return
}
return TopologyToken{
syncToken: *t,
}, nil
}
// NewStreamToken creates a new sync token for /sync
func NewStreamToken(pduPos, eduPos StreamPosition) StreamingToken {
return StreamingToken{
syncToken: syncToken{
Type: SyncTokenTypeStream,
Positions: []StreamPosition{pduPos, eduPos},
},
}
}
func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
t, err := newSyncTokenFromString(tok)
if err != nil {
return
}
if t.Type != SyncTokenTypeStream {
err = fmt.Errorf("token %s is not a streaming token", tok)
return
}
if len(t.Positions) != 2 {
err = fmt.Errorf("token %s wrong number of values, got %d want 2", tok, len(t.Positions))
return
}
return StreamingToken{
syncToken: *t,
}, nil
}
// syncToken represents a syncapi token, used for interactions with
// /sync or /messages, for example.
type syncToken struct {
Type SyncTokenType
// A list of stream positions, their meanings vary depending on the token type.
Positions []StreamPosition
}
// String translates a SyncToken to a string of the "xyyyy..." (see
// NewSyncToken to know what it represents).
func (p *syncToken) String() string {
posStr := make([]string, len(p.Positions))
for i := range p.Positions {
posStr[i] = strconv.FormatInt(int64(p.Positions[i]), 10)
}
return fmt.Sprintf("%s%s", p.Type, strings.Join(posStr, "_"))
} }
// PrevEventRef represents a reference to a previous event in a state event upgrade // PrevEventRef represents a reference to a previous event in a state event upgrade
@ -185,7 +268,7 @@ type Response struct {
} }
// NewResponse creates an empty response with initialised maps. // NewResponse creates an empty response with initialised maps.
func NewResponse(token PaginationToken) *Response { func NewResponse(token StreamingToken) *Response {
res := Response{ res := Response{
NextBatch: token.String(), NextBatch: token.String(),
} }
@ -202,14 +285,6 @@ func NewResponse(token PaginationToken) *Response {
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0)
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0)
// Fill next_batch with a pagination token. Since this is a response to a sync request, we can assume
// we'll always return a stream token.
res.NextBatch = NewPaginationTokenFromTypeAndPosition(
PaginationTokenTypeStream,
StreamPosition(token.PDUPosition),
StreamPosition(token.EDUTypingPosition),
).String()
return &res return &res
} }

View file

@ -2,26 +2,11 @@ package types
import "testing" import "testing"
func TestNewPaginationTokenFromString(t *testing.T) { func TestNewSyncTokenFromString(t *testing.T) {
shouldPass := map[string]PaginationToken{ shouldPass := map[string]syncToken{
"2": PaginationToken{ "s4_0": NewStreamToken(4, 0).syncToken,
Type: PaginationTokenTypeStream, "s3_1": NewStreamToken(3, 1).syncToken,
PDUPosition: 2, "t3_1": NewTopologyToken(3, 1).syncToken,
},
"s4": PaginationToken{
Type: PaginationTokenTypeStream,
PDUPosition: 4,
},
"s3_1": PaginationToken{
Type: PaginationTokenTypeStream,
PDUPosition: 3,
EDUTypingPosition: 1,
},
"t3_1_4": PaginationToken{
Type: PaginationTokenTypeTopology,
PDUPosition: 3,
EDUTypingPosition: 1,
},
} }
shouldFail := []string{ shouldFail := []string{
@ -32,20 +17,21 @@ func TestNewPaginationTokenFromString(t *testing.T) {
"b", "b",
"b-1", "b-1",
"-4", "-4",
"2",
} }
for test, expected := range shouldPass { for test, expected := range shouldPass {
result, err := NewPaginationTokenFromString(test) result, err := newSyncTokenFromString(test)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if *result != expected { if result.String() != expected.String() {
t.Errorf("expected %v but got %v", expected.String(), result.String()) t.Errorf("%s expected %v but got %v", test, expected.String(), result.String())
} }
} }
for _, test := range shouldFail { for _, test := range shouldFail {
if _, err := NewPaginationTokenFromString(test); err == nil { if _, err := newSyncTokenFromString(test); err == nil {
t.Errorf("input '%v' should have errored but didn't", test) t.Errorf("input '%v' should have errored but didn't", test)
} }
} }