0
0
Fork 0
mirror of https://github.com/matrix-org/dendrite synced 2024-12-13 15:53:16 +01:00

Use TransactionWriter in roomserver SQLite (#1208)

This commit is contained in:
Neil Alexander 2020-07-21 10:48:49 +01:00 committed by GitHub
parent 489f34fed7
commit d76eb1b994
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 259 additions and 148 deletions

View file

@ -49,13 +49,16 @@ const bulkSelectEventJSONSQL = `
type eventJSONStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertEventJSONStmt *sql.Stmt
bulkSelectEventJSONStmt *sql.Stmt
}
func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
s := &eventJSONStatements{}
s.db = db
s := &eventJSONStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(eventJSONSchema)
if err != nil {
return nil, err
@ -69,8 +72,10 @@ func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
func (s *eventJSONStatements) InsertEventJSON(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
) error {
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
return err
})
}
func (s *eventJSONStatements) BulkSelectEventJSON(

View file

@ -64,6 +64,7 @@ const bulkSelectEventStateKeyNIDSQL = `
type eventStateKeyStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertEventStateKeyNIDStmt *sql.Stmt
selectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyNIDStmt *sql.Stmt
@ -71,8 +72,10 @@ type eventStateKeyStatements struct {
}
func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
s := &eventStateKeyStatements{}
s.db = db
s := &eventStateKeyStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(eventStateKeysSchema)
if err != nil {
return nil, err
@ -89,12 +92,18 @@ func (s *eventStateKeyStatements) InsertEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
var eventStateKeyNID int64
var err error
var res sql.Result
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
if res, err = insertStmt.ExecContext(ctx, eventStateKey); err == nil {
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
res, err := insertStmt.ExecContext(ctx, eventStateKey)
if err != nil {
return err
}
eventStateKeyNID, err = res.LastInsertId()
}
if err != nil {
return err
}
return nil
})
return types.EventStateKeyNID(eventStateKeyNID), err
}

View file

@ -78,6 +78,7 @@ const bulkSelectEventTypeNIDSQL = `
type eventTypeStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertEventTypeNIDStmt *sql.Stmt
insertEventTypeNIDResultStmt *sql.Stmt
selectEventTypeNIDStmt *sql.Stmt
@ -85,8 +86,10 @@ type eventTypeStatements struct {
}
func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
s := &eventTypeStatements{}
s.db = db
s := &eventTypeStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(eventTypesSchema)
if err != nil {
return nil, err
@ -104,12 +107,15 @@ func (s *eventTypeStatements) InsertEventTypeNID(
ctx context.Context, tx *sql.Tx, eventType string,
) (types.EventTypeNID, error) {
var eventTypeNID int64
var err error
insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt)
resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt)
if _, err = insertStmt.ExecContext(ctx, eventType); err == nil {
err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
}
err := s.writer.Do(s.db, tx, func(tx *sql.Tx) error {
insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt)
resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt)
_, err := insertStmt.ExecContext(ctx, eventType)
if err != nil {
return err
}
return resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
})
return types.EventTypeNID(eventTypeNID), err
}

View file

@ -99,6 +99,7 @@ const selectRoomNIDForEventNIDSQL = "" +
type eventStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt
@ -115,8 +116,10 @@ type eventStatements struct {
}
func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
s := &eventStatements{}
s.db = db
s := &eventStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(eventsSchema)
if err != nil {
return nil, err
@ -151,19 +154,23 @@ func (s *eventStatements) InsertEvent(
depth int64,
) (types.EventNID, types.StateSnapshotNID, error) {
// attempt to insert: the last_row_id is the event NID
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
result, err := insertStmt.ExecContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
)
if err != nil {
return 0, 0, err
}
modified, err := result.RowsAffected()
if modified == 0 && err == nil {
return 0, 0, sql.ErrNoRows
}
eventNID, err := result.LastInsertId()
var eventNID int64
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
result, err := insertStmt.ExecContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
)
if err != nil {
return err
}
modified, err := result.RowsAffected()
if modified == 0 && err == nil {
return sql.ErrNoRows
}
eventNID, err = result.LastInsertId()
return err
})
return types.EventNID(eventNID), 0, err
}
@ -279,8 +286,10 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
func (s *eventStatements) UpdateEventState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
return err
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
return err
})
}
func (s *eventStatements) SelectEventSentToOutput(
@ -288,17 +297,15 @@ func (s *eventStatements) SelectEventSentToOutput(
) (sentToOutput bool, err error) {
selectStmt := sqlutil.TxStmt(txn, s.selectEventSentToOutputStmt)
err = selectStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
//err = s.selectEventSentToOutputStmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
if err != nil {
}
return
}
func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
_, err := updateStmt.ExecContext(ctx, int64(eventNID))
//_, err := s.updateEventSentToOutputStmt.ExecContext(ctx, int64(eventNID))
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
_, err := updateStmt.ExecContext(ctx, int64(eventNID))
return err
})
}
func (s *eventStatements) SelectEventID(

View file

@ -63,6 +63,8 @@ SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_ni
`
type inviteStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertInviteEventStmt *sql.Stmt
selectInviteActiveForUserInRoomStmt *sql.Stmt
updateInviteRetiredStmt *sql.Stmt
@ -70,7 +72,10 @@ type inviteStatements struct {
}
func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
s := &inviteStatements{}
s := &inviteStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(inviteSchema)
if err != nil {
return nil, err
@ -90,42 +95,48 @@ func (s *inviteStatements) InsertInviteEvent(
targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte,
) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
result, err := stmt.ExecContext(
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
)
if err != nil {
return false, err
}
count, err := result.RowsAffected()
if err != nil {
return false, err
}
return count != 0, nil
var count int64
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
result, err := stmt.ExecContext(
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
)
if err != nil {
return err
}
count, err = result.RowsAffected()
if err != nil {
return err
}
return nil
})
return count != 0, err
}
func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) {
// gather all the event IDs we will retire
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil {
return nil, err
}
defer (func() { err = rows.Close() })()
for rows.Next() {
var inviteEventID string
if err = rows.Scan(&inviteEventID); err != nil {
return nil, err
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
// gather all the event IDs we will retire
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil {
return err
}
eventIDs = append(eventIDs, inviteEventID)
}
// now retire the invites
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
defer (func() { err = rows.Close() })()
for rows.Next() {
var inviteEventID string
if err = rows.Scan(&inviteEventID); err != nil {
return err
}
eventIDs = append(eventIDs, inviteEventID)
}
// now retire the invites
stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
_, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
return err
})
return
}

View file

@ -76,6 +76,8 @@ const updateMembershipSQL = "" +
" WHERE room_nid = $4 AND target_nid = $5"
type membershipStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt
@ -87,7 +89,10 @@ type membershipStatements struct {
}
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
s := &membershipStatements{}
s := &membershipStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(membershipSchema)
if err != nil {
return nil, err
@ -110,9 +115,11 @@ func (s *membershipStatements) InsertMembership(
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool,
) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
return err
})
}
func (s *membershipStatements) SelectMembershipForUpdate(
@ -194,9 +201,11 @@ func (s *membershipStatements) UpdateMembership(
senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID,
) error {
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
_, err := stmt.ExecContext(
ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
)
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
_, err := stmt.ExecContext(
ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
)
return err
})
}

View file

@ -53,12 +53,17 @@ const selectPreviousEventExistsSQL = `
`
type previousEventStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertPreviousEventStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt
}
func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
s := &previousEventStatements{}
s := &previousEventStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(previousEventSchema)
if err != nil {
return nil, err
@ -77,11 +82,13 @@ func (s *previousEventStatements) InsertPreviousEvent(
previousEventReferenceSHA256 []byte,
eventNID types.EventNID,
) error {
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
_, err := stmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
)
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
_, err := stmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
)
return err
})
}
// Check if the event reference exists

View file

@ -19,6 +19,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
)
@ -43,13 +44,18 @@ const selectPublishedSQL = "" +
"SELECT published FROM roomserver_published WHERE room_id = $1"
type publishedStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
upsertPublishedStmt *sql.Stmt
selectAllPublishedStmt *sql.Stmt
selectPublishedStmt *sql.Stmt
}
func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
s := &publishedStatements{}
s := &publishedStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(publishedSchema)
if err != nil {
return nil, err
@ -64,8 +70,10 @@ func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
func (s *publishedStatements) UpsertRoomPublished(
ctx context.Context, roomID string, published bool,
) (err error) {
_, err = s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
return
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
return err
})
}
func (s *publishedStatements) SelectPublishedFromRoomID(

View file

@ -52,6 +52,8 @@ const markRedactionValidatedSQL = "" +
" UPDATE roomserver_redactions SET validated = $2 WHERE redaction_event_id = $1"
type redactionStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertRedactionStmt *sql.Stmt
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
@ -59,7 +61,10 @@ type redactionStatements struct {
}
func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
s := &redactionStatements{}
s := &redactionStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(redactionsSchema)
if err != nil {
return nil, err
@ -76,9 +81,11 @@ func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
func (s *redactionStatements) InsertRedaction(
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
) error {
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
_, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
return err
})
}
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
@ -114,7 +121,9 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
func (s *redactionStatements) MarkRedactionValidated(
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
) error {
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
_, err := stmt.ExecContext(ctx, redactionEventID, validated)
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
_, err := stmt.ExecContext(ctx, redactionEventID, validated)
return err
})
}

View file

@ -20,6 +20,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
)
@ -55,6 +56,8 @@ const deleteRoomAliasSQL = `
`
type roomAliasesStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertRoomAliasStmt *sql.Stmt
selectRoomIDFromAliasStmt *sql.Stmt
selectAliasesFromRoomIDStmt *sql.Stmt
@ -63,7 +66,10 @@ type roomAliasesStatements struct {
}
func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{}
s := &roomAliasesStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(roomAliasesSchema)
if err != nil {
return nil, err
@ -80,8 +86,10 @@ func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
func (s *roomAliasesStatements) InsertRoomAlias(
ctx context.Context, alias string, roomID string, creatorUserID string,
) (err error) {
_, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
return
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
return err
})
}
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
@ -130,6 +138,8 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
func (s *roomAliasesStatements) DeleteRoomAlias(
ctx context.Context, alias string,
) (err error) {
_, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias)
return
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
_, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias)
return err
})
}

View file

@ -64,6 +64,8 @@ const selectRoomVersionForRoomNIDSQL = "" +
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
type roomStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertRoomNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
selectLatestEventNIDsStmt *sql.Stmt
@ -74,7 +76,10 @@ type roomStatements struct {
}
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{}
s := &roomStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(roomsSchema)
if err != nil {
return nil, err
@ -94,9 +99,12 @@ func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) {
var err error
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
if _, err = insertStmt.ExecContext(ctx, roomID, roomVersion); err == nil {
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
_, err := insertStmt.ExecContext(ctx, roomID, roomVersion)
return err
})
if err == nil {
return s.SelectRoomNID(ctx, txn, roomID)
} else {
return types.RoomNID(0), err
@ -155,15 +163,17 @@ func (s *roomStatements) UpdateLatestEventNIDs(
lastEventSentNID types.EventNID,
stateSnapshotNID types.StateSnapshotNID,
) error {
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
_, err := stmt.ExecContext(
ctx,
eventNIDsAsArray(eventNIDs),
int64(lastEventSentNID),
int64(stateSnapshotNID),
roomNID,
)
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
_, err := stmt.ExecContext(
ctx,
eventNIDsAsArray(eventNIDs),
int64(lastEventSentNID),
int64(stateSnapshotNID),
roomNID,
)
return err
})
}
func (s *roomStatements) SelectRoomVersionForRoomID(

View file

@ -74,6 +74,7 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" +
type stateBlockStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt *sql.Stmt
bulkSelectStateBlockEntriesStmt *sql.Stmt
@ -81,8 +82,10 @@ type stateBlockStatements struct {
}
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{}
s.db = db
s := &stateBlockStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(stateDataSchema)
if err != nil {
return nil, err
@ -104,24 +107,26 @@ func (s *stateBlockStatements) BulkInsertStateData(
return 0, nil
}
var stateBlockNID types.StateBlockNID
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
if err != nil {
return 0, err
}
for _, entry := range entries {
_, err := txn.Stmt(s.insertStateDataStmt).ExecContext(
ctx,
int64(stateBlockNID),
int64(entry.EventTypeNID),
int64(entry.EventStateKeyNID),
int64(entry.EventNID),
)
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
if err != nil {
return 0, err
return err
}
}
return stateBlockNID, nil
for _, entry := range entries {
_, err := txn.Stmt(s.insertStateDataStmt).ExecContext(
ctx,
int64(stateBlockNID),
int64(entry.EventTypeNID),
int64(entry.EventStateKeyNID),
int64(entry.EventNID),
)
if err != nil {
return err
}
}
return nil
})
return stateBlockNID, err
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(

View file

@ -50,13 +50,16 @@ const bulkSelectStateBlockNIDsSQL = "" +
type stateSnapshotStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{}
s.db = db
s := &stateSnapshotStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(stateSnapshotSchema)
if err != nil {
return nil, err
@ -75,14 +78,19 @@ func (s *stateSnapshotStatements) InsertState(
if err != nil {
return
}
insertStmt := txn.Stmt(s.insertStateStmt)
if res, err2 := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)); err2 == nil {
lastRowID, err3 := res.LastInsertId()
if err3 != nil {
err = err3
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
insertStmt := txn.Stmt(s.insertStateStmt)
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
if err != nil {
return err
}
lastRowID, err := res.LastInsertId()
if err != nil {
return err
}
stateNID = types.StateSnapshotNID(lastRowID)
}
return nil
})
return
}

View file

@ -44,12 +44,17 @@ const selectTransactionEventIDSQL = `
`
type transactionStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertTransactionStmt *sql.Stmt
selectTransactionEventIDStmt *sql.Stmt
}
func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
s := &transactionStatements{}
s := &transactionStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(transactionsSchema)
if err != nil {
return nil, err
@ -68,11 +73,13 @@ func (s *transactionStatements) InsertTransaction(
userID string,
eventID string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
_, err = stmt.ExecContext(
ctx, transactionID, sessionID, userID, eventID,
)
return
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
_, err := stmt.ExecContext(
ctx, transactionID, sessionID, userID, eventID,
)
return err
})
}
func (s *transactionStatements) SelectTransactionEventID(