From 43308d2f3f8fcf9bdb3ec55d4e679b576cc19488 Mon Sep 17 00:00:00 2001 From: Alex Chen Date: Sat, 24 Aug 2019 00:55:40 +0800 Subject: [PATCH] Associate transactions with session IDs instead of device IDs (#789) --- clientapi/auth/authtypes/device.go | 4 ++++ .../auth/storage/devices/devices_table.go | 19 +++++++++++++---- clientapi/routing/sendevent.go | 8 +++---- roomserver/api/input.go | 4 ++-- roomserver/input/events.go | 6 +++--- roomserver/storage/storage.go | 12 +++++------ roomserver/storage/transactions_table.go | 18 ++++++++-------- syncapi/storage/output_room_events_table.go | 21 ++++++++++--------- syncapi/storage/syncserver.go | 2 +- 9 files changed, 55 insertions(+), 39 deletions(-) diff --git a/clientapi/auth/authtypes/device.go b/clientapi/auth/authtypes/device.go index a6d3a7b08..930ab3956 100644 --- a/clientapi/auth/authtypes/device.go +++ b/clientapi/auth/authtypes/device.go @@ -21,5 +21,9 @@ type Device struct { // The access_token granted to this device. // This uniquely identifies the device from all other devices and clients. AccessToken string + // The unique ID of the session identified by the access token. + // Can be used as a secure substitution in places where data needs to be + // associated with access tokens. + SessionID int64 // TODO: display name, last used timestamp, keys, etc } diff --git a/clientapi/auth/storage/devices/devices_table.go b/clientapi/auth/storage/devices/devices_table.go index 60aa563a2..d011d25c9 100644 --- a/clientapi/auth/storage/devices/devices_table.go +++ b/clientapi/auth/storage/devices/devices_table.go @@ -27,11 +27,19 @@ import ( ) const devicesSchema = ` +-- This sequence is used for automatic allocation of session_id. +CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1; + -- Stores data about devices. CREATE TABLE IF NOT EXISTS device_devices ( -- The access token granted to this device. This has to be the primary key -- so we can distinguish which device is making a given request. access_token TEXT NOT NULL PRIMARY KEY, + -- The auto-allocated unique ID of the session identified by the access token. + -- This can be used as a secure substitution of the access token in situations + -- where data is associated with access tokens (e.g. transaction storage), + -- so we don't have to store users' access tokens everywhere. + session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'), -- The device identifier. This only needs to uniquely identify a device for a given user, not globally. -- access_tokens will be clobbered based on the device ID for a user. device_id TEXT NOT NULL, @@ -51,10 +59,11 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca ` const insertDeviceSQL = "" + - "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" + + " RETURNING session_id" const selectDeviceByTokenSQL = "" + - "SELECT device_id, localpart FROM device_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" @@ -120,14 +129,16 @@ func (s *devicesStatements) insertDevice( displayName *string, ) (*authtypes.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 + var sessionID int64 stmt := common.TxStmt(txn, s.insertDeviceStmt) - if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName); err != nil { + if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil { return nil, err } return &authtypes.Device{ ID: id, UserID: userutil.MakeUserID(localpart, s.serverName), AccessToken: accessToken, + SessionID: sessionID, }, nil } @@ -161,7 +172,7 @@ func (s *devicesStatements) selectDeviceByToken( var dev authtypes.Device var localpart string stmt := s.selectDeviceByTokenStmt - err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.ID, &localpart) + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) if err == nil { dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.AccessToken = accessToken diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 9696b360e..76e36cd46 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -60,18 +60,18 @@ func SendEvent( return *resErr } - var txnAndDeviceID *api.TransactionID + var txnAndSessionID *api.TransactionID if txnID != nil { - txnAndDeviceID = &api.TransactionID{ + txnAndSessionID = &api.TransactionID{ TransactionID: *txnID, - DeviceID: device.ID, + SessionID: device.SessionID, } } // pass the new event to the roomserver and receive the correct event ID // event ID in case of duplicate transaction is discarded eventID, err := producer.SendEvents( - req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndDeviceID, + req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndSessionID, ) if err != nil { return httputil.LogThenError(req, err) diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 2c2e27c62..9643a927c 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -75,9 +75,9 @@ type InputRoomEvent struct { } // TransactionID contains the transaction ID sent by a client when sending an -// event, along with the ID of that device. +// event, along with the ID of the client session. type TransactionID struct { - DeviceID string `json:"device_id"` + SessionID int64 `json:"session_id"` TransactionID string `json:"id"` } diff --git a/roomserver/input/events.go b/roomserver/input/events.go index feb15b3e1..b30c39928 100644 --- a/roomserver/input/events.go +++ b/roomserver/input/events.go @@ -32,7 +32,7 @@ type RoomEventDatabase interface { StoreEvent( ctx context.Context, event gomatrixserverlib.Event, - txnAndDeviceID *api.TransactionID, + txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, ) (types.RoomNID, types.StateAtEvent, error) // Look up the state entries for a list of string event IDs @@ -67,7 +67,7 @@ type RoomEventDatabase interface { // Returns an empty string if no such event exists. GetTransactionEventID( ctx context.Context, transactionID string, - deviceID string, userID string, + sessionID int64, userID string, ) (string, error) } @@ -100,7 +100,7 @@ func processRoomEvent( if input.TransactionID != nil { tdID := input.TransactionID eventID, err = db.GetTransactionEventID( - ctx, tdID.TransactionID, tdID.DeviceID, input.Event.Sender(), + ctx, tdID.TransactionID, tdID.SessionID, input.Event.Sender(), ) // On error OR event with the transaction already processed/processesing if err != nil || eventID != "" { diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index 71c13b7ca..7e8eb98c9 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -47,7 +47,7 @@ func Open(dataSourceName string) (*Database, error) { // StoreEvent implements input.EventDatabase func (d *Database) StoreEvent( ctx context.Context, event gomatrixserverlib.Event, - txnAndDeviceID *api.TransactionID, authEventNIDs []types.EventNID, + txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, ) (types.RoomNID, types.StateAtEvent, error) { var ( roomNID types.RoomNID @@ -58,10 +58,10 @@ func (d *Database) StoreEvent( err error ) - if txnAndDeviceID != nil { + if txnAndSessionID != nil { if err = d.statements.insertTransaction( - ctx, txnAndDeviceID.TransactionID, - txnAndDeviceID.DeviceID, event.Sender(), event.EventID(), + ctx, txnAndSessionID.TransactionID, + txnAndSessionID.SessionID, event.Sender(), event.EventID(), ); err != nil { return 0, types.StateAtEvent{}, err } @@ -322,9 +322,9 @@ func (d *Database) GetLatestEventsForUpdate( // GetTransactionEventID implements input.EventDatabase func (d *Database) GetTransactionEventID( ctx context.Context, transactionID string, - deviceID string, userID string, + sessionID int64, userID string, ) (string, error) { - eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, deviceID, userID) + eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID) if err == sql.ErrNoRows { return "", nil } diff --git a/roomserver/storage/transactions_table.go b/roomserver/storage/transactions_table.go index e9c904cc8..b98ea3f33 100644 --- a/roomserver/storage/transactions_table.go +++ b/roomserver/storage/transactions_table.go @@ -23,8 +23,8 @@ const transactionsSchema = ` CREATE TABLE IF NOT EXISTS roomserver_transactions ( -- The transaction ID of the event. transaction_id TEXT NOT NULL, - -- The device ID of the originating transaction. - device_id TEXT NOT NULL, + -- The session ID of the originating transaction. + session_id BIGINT NOT NULL, -- User ID of the sender who authored the event user_id TEXT NOT NULL, -- Event ID corresponding to the transaction @@ -32,16 +32,16 @@ CREATE TABLE IF NOT EXISTS roomserver_transactions ( event_id TEXT NOT NULL, -- A transaction ID is unique for a user and device -- This automatically creates an index. - PRIMARY KEY (transaction_id, device_id, user_id) + PRIMARY KEY (transaction_id, session_id, user_id) ); ` const insertTransactionSQL = "" + - "INSERT INTO roomserver_transactions (transaction_id, device_id, user_id, event_id)" + + "INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)" + " VALUES ($1, $2, $3, $4)" const selectTransactionEventIDSQL = "" + "SELECT event_id FROM roomserver_transactions" + - " WHERE transaction_id = $1 AND device_id = $2 AND user_id = $3" + " WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3" type transactionStatements struct { insertTransactionStmt *sql.Stmt @@ -63,12 +63,12 @@ func (s *transactionStatements) prepare(db *sql.DB) (err error) { func (s *transactionStatements) insertTransaction( ctx context.Context, transactionID string, - deviceID string, + sessionID int64, userID string, eventID string, ) (err error) { _, err = s.insertTransactionStmt.ExecContext( - ctx, transactionID, deviceID, userID, eventID, + ctx, transactionID, sessionID, userID, eventID, ) return } @@ -76,11 +76,11 @@ func (s *transactionStatements) insertTransaction( func (s *transactionStatements) selectTransactionEventID( ctx context.Context, transactionID string, - deviceID string, + sessionID int64, userID string, ) (eventID string, err error) { err = s.selectTransactionEventIDStmt.QueryRowContext( - ctx, transactionID, deviceID, userID, + ctx, transactionID, sessionID, userID, ).Scan(&eventID) return } diff --git a/syncapi/storage/output_room_events_table.go b/syncapi/storage/output_room_events_table.go index 8fbeb18c9..2df2a96a1 100644 --- a/syncapi/storage/output_room_events_table.go +++ b/syncapi/storage/output_room_events_table.go @@ -54,7 +54,7 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( -- if there is no delta. add_state_ids TEXT[], remove_state_ids TEXT[], - device_id TEXT, -- The local device that sent the event, if any + session_id BIGINT, -- The client session that sent the event, if any transaction_id TEXT -- The transaction id used to send the event, if any ); -- for event selection @@ -63,14 +63,14 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev const insertEventSQL = "" + "INSERT INTO syncapi_output_room_events (" + - "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, device_id, transaction_id" + + "room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id" + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id" const selectEventsSQL = "" + "SELECT id, event_json FROM syncapi_output_room_events WHERE event_id = ANY($1)" const selectRecentEventsSQL = "" + - "SELECT id, event_json, device_id, transaction_id FROM syncapi_output_room_events" + + "SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events" + " WHERE room_id = $1 AND id > $2 AND id <= $3" + " ORDER BY id DESC LIMIT $4" @@ -221,9 +221,10 @@ func (s *outputRoomEventsStatements) insertEvent( event *gomatrixserverlib.Event, addState, removeState []string, transactionID *api.TransactionID, ) (streamPos int64, err error) { - var deviceID, txnID *string + var txnID *string + var sessionID *int64 if transactionID != nil { - deviceID = &transactionID.DeviceID + sessionID = &transactionID.SessionID txnID = &transactionID.TransactionID } @@ -246,7 +247,7 @@ func (s *outputRoomEventsStatements) insertEvent( containsURL, pq.StringArray(addState), pq.StringArray(removeState), - deviceID, + sessionID, txnID, ).Scan(&streamPos) return @@ -296,11 +297,11 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) { var ( streamPos int64 eventBytes []byte - deviceID *string + sessionID *int64 txnID *string transactionID *api.TransactionID ) - if err := rows.Scan(&streamPos, &eventBytes, &deviceID, &txnID); err != nil { + if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &txnID); err != nil { return nil, err } // TODO: Handle redacted events @@ -309,9 +310,9 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) { return nil, err } - if deviceID != nil && txnID != nil { + if sessionID != nil && txnID != nil { transactionID = &api.TransactionID{ - DeviceID: *deviceID, + SessionID: *sessionID, TransactionID: *txnID, } } diff --git a/syncapi/storage/syncserver.go b/syncapi/storage/syncserver.go index fb883702c..cda44d2e3 100644 --- a/syncapi/storage/syncserver.go +++ b/syncapi/storage/syncserver.go @@ -893,7 +893,7 @@ func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrix for i := 0; i < len(in); i++ { out[i] = in[i].Event if device != nil && in[i].transactionID != nil { - if device.UserID == in[i].Sender() && device.ID == in[i].transactionID.DeviceID { + if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID { err := out[i].SetUnsignedField( "transaction_id", in[i].transactionID.TransactionID, )