0
0
Fork 0
mirror of https://github.com/matrix-org/dendrite synced 2025-01-22 07:30:19 +01:00
dendrite/federationsender/storage/sqlite3/queue_edus_table.go
Neil Alexander 9d53351dc2
Component-wide TransactionWriters (#1290)
* Offset updates take place using TransactionWriter

* Refactor TransactionWriter in current state server

* Refactor TransactionWriter in federation sender

* Refactor TransactionWriter in key server

* Refactor TransactionWriter in media API

* Refactor TransactionWriter in server key API

* Refactor TransactionWriter in sync API

* Refactor TransactionWriter in user API

* Fix deadlocking Sync API tests

* Un-deadlock device database

* Fix appservice API

* Rename TransactionWriters to Writers

* Move writers up a layer in sync API

* Document sqlutil.Writer interface

* Add note to Writer documentation
2020-08-21 10:42:08 +01:00

207 lines
5.9 KiB
Go

// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
// queueEDUsSchema is the DDL executed when the table is set up. It creates
// the federationsender_queue_edus table (if missing), in which each row ties
// an EDU type and a server name to the NID of a JSON blob, and a unique index
// over (json_nid, server_name) so a given JSON NID can be recorded at most
// once per server name.
const queueEDUsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
-- The type of the event (informational).
edu_type TEXT NOT NULL,
-- The domain part of the user ID the EDU event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_edus_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
ON federationsender_queue_edus (json_nid, server_name);
`
// SQL text for the statements that operate on federationsender_queue_edus.
const (
	// insertQueueEDUSQL adds one (edu_type, server_name, json_nid) row.
	insertQueueEDUSQL = "INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid) VALUES ($1, $2, $3)"

	// deleteQueueEDUsSQL removes rows for one server name whose json_nid is
	// in a list. The "($2)" fragment is a template: DeleteQueueEDUs replaces
	// it with one placeholder per NID before preparing the statement.
	deleteQueueEDUsSQL = "DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"

	// selectQueueEDUSQL returns up to $2 JSON NIDs queued for server $1.
	selectQueueEDUSQL = "SELECT json_nid FROM federationsender_queue_edus WHERE server_name = $1 LIMIT $2"

	// selectQueueEDUReferenceJSONCountSQL counts the rows that reference
	// JSON NID $1.
	selectQueueEDUReferenceJSONCountSQL = "SELECT COUNT(*) FROM federationsender_queue_edus WHERE json_nid = $1"

	// selectQueueEDUCountSQL counts the rows queued for server $1.
	selectQueueEDUCountSQL = "SELECT COUNT(*) FROM federationsender_queue_edus WHERE server_name = $1"

	// selectQueueServerNamesSQL lists each distinct server name that has at
	// least one row in the table.
	selectQueueServerNamesSQL = "SELECT DISTINCT server_name FROM federationsender_queue_edus"
)
// queueEDUsStatements holds the database handle and the prepared statements
// for the federationsender_queue_edus table.
//
// There is no prepared delete statement: DeleteQueueEDUs builds its SQL on
// each call because the number of placeholders depends on how many NIDs are
// being deleted.
type queueEDUsStatements struct {
// db is the handle the statements below are prepared against.
db *sql.DB
// insertQueueEDUStmt is prepared from insertQueueEDUSQL.
insertQueueEDUStmt *sql.Stmt
// selectQueueEDUStmt is prepared from selectQueueEDUSQL.
selectQueueEDUStmt *sql.Stmt
// selectQueueEDUReferenceJSONCountStmt is prepared from
// selectQueueEDUReferenceJSONCountSQL.
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
// selectQueueEDUCountStmt is prepared from selectQueueEDUCountSQL.
selectQueueEDUCountStmt *sql.Stmt
// selectQueueEDUServerNamesStmt is prepared from selectQueueServerNamesSQL.
selectQueueEDUServerNamesStmt *sql.Stmt
}
// NewSQLiteQueueEDUsTable executes the table schema against db and then
// prepares every statement held by queueEDUsStatements.
//
// The first failure (schema execution or any prepare) aborts setup and is
// returned as err. In that case the returned value is non-nil but only
// partially initialised, so callers should not use it.
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
	s = &queueEDUsStatements{db: db}

	if _, err = db.Exec(queueEDUsSchema); err != nil {
		return s, err
	}

	// Each entry pairs a statement field with the SQL it is prepared from.
	// They are prepared in the order listed.
	toPrepare := []struct {
		target **sql.Stmt
		query  string
	}{
		{&s.insertQueueEDUStmt, insertQueueEDUSQL},
		{&s.selectQueueEDUStmt, selectQueueEDUSQL},
		{&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL},
		{&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL},
		{&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL},
	}
	for _, p := range toPrepare {
		if *p.target, err = db.Prepare(p.query); err != nil {
			return s, err
		}
	}
	return s, nil
}
// InsertQueueEDU records that the EDU whose JSON is stored under nid is
// queued for serverName, together with its (informational) eduType.
//
// The prepared insert statement is passed through sqlutil.TxStmt with txn
// before it is executed. Any error from the execution is returned unchanged.
func (s *queueEDUsStatements) InsertQueueEDU(
	ctx context.Context,
	txn *sql.Tx,
	eduType string,
	serverName gomatrixserverlib.ServerName,
	nid int64,
) error {
	insert := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
	// Argument order follows the column list: edu_type, server_name, json_nid.
	if _, err := insert.ExecContext(ctx, eduType, serverName, nid); err != nil {
		return err
	}
	return nil
}
// DeleteQueueEDUs removes the queue rows for serverName whose JSON NID is one
// of jsonNIDs.
//
// The statement cannot be prepared ahead of time because the number of
// placeholders depends on len(jsonNIDs), so it is built and prepared on every
// call and closed again before returning. It is prepared on txn when one is
// supplied and on the table's database handle otherwise.
//
// An empty jsonNIDs is a no-op and returns nil without touching the database.
// Errors from preparing the statement are wrapped; errors from executing it
// are returned unchanged.
func (s *queueEDUsStatements) DeleteQueueEDUs(
	ctx context.Context, txn *sql.Tx,
	serverName gomatrixserverlib.ServerName,
	jsonNIDs []int64,
) error {
	// With no NIDs the IN list would be empty and could match nothing.
	if len(jsonNIDs) == 0 {
		return nil
	}

	// Expand the "($2)" template into one placeholder per NID, numbered
	// after $1 (the server name).
	deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)

	var (
		deleteStmt *sql.Stmt
		err        error
	)
	if txn != nil {
		deleteStmt, err = txn.PrepareContext(ctx, deleteSQL)
	} else {
		deleteStmt, err = s.db.PrepareContext(ctx, deleteSQL)
	}
	if err != nil {
		return fmt.Errorf("DeleteQueueEDUs: preparing delete statement: %w", err)
	}
	// The statement is single-use, so release it as soon as we are done
	// rather than leaving it open until the transaction ends.
	defer internal.CloseAndLogIfError(ctx, deleteStmt, "DeleteQueueEDUs: deleteStmt.Close() failed")

	params := make([]interface{}, 0, len(jsonNIDs)+1)
	params = append(params, serverName)
	for _, nid := range jsonNIDs {
		params = append(params, nid)
	}

	// deleteStmt is already bound to txn (when present), so it is executed
	// directly; wrapping it with Tx.Stmt again would re-prepare it.
	_, err = deleteStmt.ExecContext(ctx, params...)
	return err
}
// SelectQueueEDUs returns up to limit JSON NIDs queued for serverName.
//
// The prepared select statement is passed through sqlutil.TxStmt with txn
// before it is queried. The result is nil when no rows match. Any error from
// running the query, scanning a row, or iterating the result set is returned
// with a nil slice.
func (s *queueEDUsStatements) SelectQueueEDUs(
	ctx context.Context, txn *sql.Tx,
	serverName gomatrixserverlib.ServerName,
	limit int,
) ([]int64, error) {
	stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt)
	rows, err := stmt.QueryContext(ctx, serverName, limit)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")

	var result []int64
	for rows.Next() {
		var nid int64
		if err = rows.Scan(&nid); err != nil {
			return nil, err
		}
		result = append(result, nid)
	}
	// rows.Next also returns false when iteration fails part-way through;
	// without this check a truncated list would be reported as success.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// SelectQueueEDUReferenceJSONCount returns how many queue rows reference the
// JSON blob identified by jsonNID.
//
// If the query yields no row at all (sql.ErrNoRows) the result is -1 with a
// nil error. Any other error is returned alongside whatever was scanned.
func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
	ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) {
	stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt)

	var refs int64
	switch err := stmt.QueryRowContext(ctx, jsonNID).Scan(&refs); err {
	case sql.ErrNoRows:
		return -1, nil
	default:
		return refs, err
	}
}
// SelectQueueEDUCount returns how many queue rows exist for serverName.
//
// If the query yields no row at all (sql.ErrNoRows) the result is 0 with a
// nil error. Any other error is returned alongside whatever was scanned.
func (s *queueEDUsStatements) SelectQueueEDUCount(
	ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
	var total int64
	err := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt).
		QueryRowContext(ctx, serverName).
		Scan(&total)
	if err != sql.ErrNoRows {
		return total, err
	}
	// Having nothing queued for this server name is not an error
	// condition; report it as a zero count.
	return 0, nil
}
// SelectQueueEDUServerNames returns each distinct server name that has at
// least one row in the queue table.
//
// The result is nil when the table is empty. A query or scan failure returns
// a nil slice and the error. An iteration failure reported by rows.Err is
// returned together with the names collected so far.
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
	ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {
	rows, err := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt).QueryContext(ctx)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")

	var names []gomatrixserverlib.ServerName
	for rows.Next() {
		var name gomatrixserverlib.ServerName
		if scanErr := rows.Scan(&name); scanErr != nil {
			return nil, scanErr
		}
		names = append(names, name)
	}
	return names, rows.Err()
}