diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 5c3681382..0c12b1117 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -99,7 +99,11 @@ func (r *queryKeysRequest) GetTimeout() time.Duration { if r.Timeout == 0 { return 10 * time.Second } - return time.Duration(r.Timeout) * time.Millisecond + timeout := time.Duration(r.Timeout) * time.Millisecond + if timeout > time.Second*20 { + timeout = time.Second * 20 + } + return timeout } func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse { diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 49ef03054..ff0968b27 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -257,9 +257,6 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey) res.Failures = make(map[string]interface{}) - // get cross-signing keys from the database - a.crossSigningKeysFromDatabase(ctx, req, res) - // make a map from domain to device keys domainToDeviceKeys := make(map[string]map[string][]string) domainToCrossSigningKeys := make(map[string]map[string]struct{}) @@ -336,6 +333,10 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys) } + // Now that we've done the potentially expensive work of asking the federation, + // try filling the cross-signing keys from the database that we know about. + a.crossSigningKeysFromDatabase(ctx, req, res) + // Finally, append signatures that we know about // TODO: This is horrible because we need to round-trip the signature from // JSON, add the signatures and marshal it again, for some reason? diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go index 8b2a865b9..4536b7d80 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/keyserver/storage/postgres/cross_signing_sigs_table.go @@ -42,7 +42,7 @@ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_s const selectCrossSigningSigsForTargetSQL = "" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + - " WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3" + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3" const upsertCrossSigningSigsForTargetSQL = "" + "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go index ea431151e..7a153e8fb 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/keyserver/storage/sqlite3/cross_signing_sigs_table.go @@ -42,7 +42,7 @@ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_s const selectCrossSigningSigsForTargetSQL = "" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + - " WHERE (origin_user_id = $1 OR origin_user_id = target_user_id) AND target_user_id = $2 AND target_key_id = $3" + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4" const upsertCrossSigningSigsForTargetSQL = "" + "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + @@ -85,7 +85,7 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, ) (r types.CrossSigningSigMap, err error) { - rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID) + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID) if err != nil { return nil, err }