mirror of
https://github.com/matrix-org/dendrite
synced 2025-01-18 22:31:57 +01:00
ec716793eb
* Initial federation sender -> federation API refactoring * Move base into own package, avoids import cycle * Fix build errors * Fix tests * Add signing key server tables * Try to fold signing key server into federation API * Fix dendritejs builds * Update embedded interfaces * Fix panic, fix lint error * Update configs, docker * Rename some things * Reuse same keyring on the implementing side * Fix federation tests, `NewBaseDendrite` can accept freeform options * Fix build * Update create_db, configs * Name tables back * Don't rename federationsender consumer for now
702 lines
23 KiB
Go
702 lines
23 KiB
Go
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package internal
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
|
"github.com/matrix-org/dendrite/keyserver/api"
|
|
"github.com/matrix-org/dendrite/keyserver/producers"
|
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
|
"github.com/matrix-org/gomatrixserverlib"
|
|
"github.com/matrix-org/util"
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
type KeyInternalAPI struct {
|
|
DB storage.Database
|
|
ThisServer gomatrixserverlib.ServerName
|
|
FedClient fedsenderapi.FederationClient
|
|
UserAPI userapi.UserInternalAPI
|
|
Producer *producers.KeyChange
|
|
Updater *DeviceListUpdater
|
|
}
|
|
|
|
func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) {
|
|
a.UserAPI = i
|
|
}
|
|
|
|
func (a *KeyInternalAPI) InputDeviceListUpdate(
|
|
ctx context.Context, req *api.InputDeviceListUpdateRequest, res *api.InputDeviceListUpdateResponse,
|
|
) {
|
|
err := a.Updater.Update(ctx, req.Event)
|
|
if err != nil {
|
|
res.Error = &api.KeyError{
|
|
Err: fmt.Sprintf("failed to update device list: %s", err),
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
|
|
if req.Partition < 0 {
|
|
req.Partition = a.Producer.DefaultPartition()
|
|
}
|
|
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Partition, req.Offset, req.ToOffset)
|
|
if err != nil {
|
|
res.Error = &api.KeyError{
|
|
Err: err.Error(),
|
|
}
|
|
}
|
|
res.Offset = latest
|
|
res.Partition = req.Partition
|
|
res.UserIDs = userIDs
|
|
}
|
|
|
|
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
|
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
|
a.uploadLocalDeviceKeys(ctx, req, res)
|
|
a.uploadOneTimeKeys(ctx, req, res)
|
|
}
|
|
|
|
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
|
|
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
|
|
res.Failures = make(map[string]interface{})
|
|
// wrap request map in a top-level by-domain map
|
|
domainToDeviceKeys := make(map[string]map[string]map[string]string)
|
|
for userID, val := range req.OneTimeKeys {
|
|
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
|
if err != nil {
|
|
continue // ignore invalid users
|
|
}
|
|
nested, ok := domainToDeviceKeys[string(serverName)]
|
|
if !ok {
|
|
nested = make(map[string]map[string]string)
|
|
}
|
|
nested[userID] = val
|
|
domainToDeviceKeys[string(serverName)] = nested
|
|
}
|
|
// claim local keys
|
|
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
|
|
keys, err := a.DB.ClaimKeys(ctx, local)
|
|
if err != nil {
|
|
res.Error = &api.KeyError{
|
|
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
|
|
}
|
|
}
|
|
util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys")
|
|
for _, key := range keys {
|
|
_, ok := res.OneTimeKeys[key.UserID]
|
|
if !ok {
|
|
res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage)
|
|
}
|
|
_, ok = res.OneTimeKeys[key.UserID][key.DeviceID]
|
|
if !ok {
|
|
res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
|
|
}
|
|
for keyID, keyJSON := range key.KeyJSON {
|
|
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
|
|
}
|
|
}
|
|
delete(domainToDeviceKeys, string(a.ThisServer))
|
|
}
|
|
if len(domainToDeviceKeys) > 0 {
|
|
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
|
}
|
|
}
|
|
|
|
func (a *KeyInternalAPI) claimRemoteKeys(
|
|
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
|
|
) {
|
|
resultCh := make(chan *gomatrixserverlib.RespClaimKeys, len(domainToDeviceKeys))
|
|
// allows us to wait until all federation servers have been poked
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(domainToDeviceKeys))
|
|
// mutex for failures
|
|
var failMu sync.Mutex
|
|
util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers")
|
|
|
|
// fan out
|
|
for d, k := range domainToDeviceKeys {
|
|
go func(domain string, keysToClaim map[string]map[string]string) {
|
|
defer wg.Done()
|
|
fedCtx, cancel := context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
|
|
if err != nil {
|
|
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
|
|
failMu.Lock()
|
|
res.Failures[domain] = map[string]interface{}{
|
|
"message": err.Error(),
|
|
}
|
|
failMu.Unlock()
|
|
return
|
|
}
|
|
resultCh <- &claimKeyRes
|
|
}(d, k)
|
|
}
|
|
|
|
// Close the result channel when the goroutines have quit so the for .. range exits
|
|
go func() {
|
|
wg.Wait()
|
|
close(resultCh)
|
|
}()
|
|
|
|
keysClaimed := 0
|
|
for result := range resultCh {
|
|
for userID, nest := range result.OneTimeKeys {
|
|
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
|
|
for deviceID, nest2 := range nest {
|
|
res.OneTimeKeys[userID][deviceID] = make(map[string]json.RawMessage)
|
|
for keyIDWithAlgo, otk := range nest2 {
|
|
keyJSON, err := json.Marshal(otk)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON
|
|
keysClaimed++
|
|
}
|
|
}
|
|
}
|
|
}
|
|
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
|
|
}
|
|
|
|
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) {
|
|
if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
|
|
res.Error = &api.KeyError{
|
|
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) {
|
|
count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
|
if err != nil {
|
|
res.Error = &api.KeyError{
|
|
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
|
|
}
|
|
return
|
|
}
|
|
res.Count = *count
|
|
}
|
|
|
|
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
|
|
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil)
|
|
if err != nil {
|
|
res.Error = &api.KeyError{
|
|
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
|
|
}
|
|
return
|
|
}
|
|
maxStreamID := 0
|
|
for _, m := range msgs {
|
|
if m.StreamID > maxStreamID {
|
|
maxStreamID = m.StreamID
|
|
}
|
|
}
|
|
// remove deleted devices
|
|
var result []api.DeviceMessage
|
|
for _, m := range msgs {
|
|
if m.KeyJSON == nil {
|
|
continue
|
|
}
|
|
result = append(result, m)
|
|
}
|
|
res.Devices = result
|
|
res.StreamID = maxStreamID
|
|
}
|
|
|
|
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
|
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
|
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
|
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
|
res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
|
res.Failures = make(map[string]interface{})
|
|
|
|
// get cross-signing keys from the database
|
|
a.crossSigningKeysFromDatabase(ctx, req, res)
|
|
|
|
// make a map from domain to device keys
|
|
domainToDeviceKeys := make(map[string]map[string][]string)
|
|
domainToCrossSigningKeys := make(map[string]map[string]struct{})
|
|
for userID, deviceIDs := range req.UserToDevices {
|
|
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
|
if err != nil {
|
|
continue // ignore invalid users
|
|
}
|
|
domain := string(serverName)
|
|
// query local devices
|
|
if serverName == a.ThisServer {
|
|
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
|
if err != nil {
|
|
res.Error = &api.KeyError{
|
|
Err: fmt.Sprintf("failed to query local device keys: %s", err),
|
|
}
|
|
return
|
|
}
|
|
|
|
// pull out display names after we have the keys so we handle wildcards correctly
|
|
var dids []string
|
|
for _, dk := range deviceKeys {
|
|
dids = append(dids, dk.DeviceID)
|
|
}
|
|
var queryRes userapi.QueryDeviceInfosResponse
|
|
err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{
|
|
DeviceIDs: dids,
|
|
}, &queryRes)
|
|
if err != nil {
|
|
util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
|
|
}
|
|
|
|
if res.DeviceKeys[userID] == nil {
|
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
|
}
|
|
for _, dk := range deviceKeys {
|
|
if len(dk.KeyJSON) == 0 {
|
|
continue // don't include blank keys
|
|
}
|
|
// inject display name if known (either locally or remotely)
|
|
displayName := dk.DisplayName
|
|
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
|
|
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
|
|
}
|
|
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
|
DisplayName string `json:"device_display_name,omitempty"`
|
|
}{displayName})
|
|
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
|
}
|
|
} else {
|
|
domainToDeviceKeys[domain] = make(map[string][]string)
|
|
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
|
|
}
|
|
// work out if our cross-signing request for this user was
|
|
// satisfied, if not add them to the list of things to fetch
|
|
if _, ok := res.MasterKeys[userID]; !ok {
|
|
if _, ok := domainToCrossSigningKeys[domain]; !ok {
|
|
domainToCrossSigningKeys[domain] = make(map[string]struct{})
|
|
}
|
|
domainToCrossSigningKeys[domain][userID] = struct{}{}
|
|
}
|
|
if _, ok := res.SelfSigningKeys[userID]; !ok {
|
|
if _, ok := domainToCrossSigningKeys[domain]; !ok {
|
|
domainToCrossSigningKeys[domain] = make(map[string]struct{})
|
|
}
|
|
domainToCrossSigningKeys[domain][userID] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
|
|
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
|
|
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
|
|
// perform key queries for remote devices
|
|
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
|
|
}
|
|
|
|
// Finally, append signatures that we know about
|
|
// TODO: This is horrible because we need to round-trip the signature from
|
|
// JSON, add the signatures and marshal it again, for some reason?
|
|
for userID, forUserID := range res.DeviceKeys {
|
|
for keyID, key := range forUserID {
|
|
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, userID, gomatrixserverlib.KeyID(keyID))
|
|
if err != nil {
|
|
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
|
|
continue
|
|
}
|
|
if len(sigMap) == 0 {
|
|
continue
|
|
}
|
|
var deviceKey gomatrixserverlib.DeviceKeys
|
|
if err = json.Unmarshal(key, &deviceKey); err != nil {
|
|
continue
|
|
}
|
|
for sourceUserID, forSourceUser := range sigMap {
|
|
for sourceKeyID, sourceSig := range forSourceUser {
|
|
deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
|
|
}
|
|
}
|
|
if js, err := json.Marshal(deviceKey); err == nil {
|
|
res.DeviceKeys[userID][keyID] = js
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
|
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
|
|
) map[string]map[string][]string {
|
|
fetchRemote := make(map[string]map[string][]string)
|
|
for domain, userToDeviceMap := range domainToDeviceKeys {
|
|
for userID, deviceIDs := range userToDeviceMap {
|
|
// we can't safely return keys from the db when all devices are requested as we don't
|
|
// know if one has just been added.
|
|
if len(deviceIDs) > 0 {
|
|
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
|
|
if err == nil {
|
|
continue
|
|
}
|
|
util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
|
|
}
|
|
// fetch device lists from remote
|
|
if _, ok := fetchRemote[domain]; !ok {
|
|
fetchRemote[domain] = make(map[string][]string)
|
|
}
|
|
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
|
|
|
|
}
|
|
}
|
|
return fetchRemote
|
|
}
|
|
|
|
func (a *KeyInternalAPI) queryRemoteKeys(
|
|
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
|
|
domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
|
|
) {
|
|
resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys))
|
|
// allows us to wait until all federation servers have been poked
|
|
var wg sync.WaitGroup
|
|
// mutex for writing directly to res (e.g failures)
|
|
var respMu sync.Mutex
|
|
|
|
domains := map[string]struct{}{}
|
|
for domain := range domainToDeviceKeys {
|
|
if domain == string(a.ThisServer) {
|
|
continue
|
|
}
|
|
domains[domain] = struct{}{}
|
|
}
|
|
for domain := range domainToCrossSigningKeys {
|
|
if domain == string(a.ThisServer) {
|
|
continue
|
|
}
|
|
domains[domain] = struct{}{}
|
|
}
|
|
wg.Add(len(domains))
|
|
|
|
// fan out
|
|
for domain := range domains {
|
|
go a.queryRemoteKeysOnServer(
|
|
ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain],
|
|
&wg, &respMu, timeout, resultCh, res,
|
|
)
|
|
}
|
|
|
|
// Close the result channel when the goroutines have quit so the for .. range exits
|
|
go func() {
|
|
wg.Wait()
|
|
close(resultCh)
|
|
}()
|
|
|
|
for result := range resultCh {
|
|
for userID, nest := range result.DeviceKeys {
|
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
|
for deviceID, deviceKey := range nest {
|
|
keyJSON, err := json.Marshal(deviceKey)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
res.DeviceKeys[userID][deviceID] = keyJSON
|
|
}
|
|
}
|
|
|
|
for userID, body := range result.MasterKeys {
|
|
res.MasterKeys[userID] = body
|
|
}
|
|
|
|
for userID, body := range result.SelfSigningKeys {
|
|
res.SelfSigningKeys[userID] = body
|
|
}
|
|
|
|
// TODO: do we want to persist these somewhere now
|
|
// that we have fetched them?
|
|
}
|
|
}
|
|
|
|
func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
|
ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
|
|
wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
|
|
res *api.QueryKeysResponse,
|
|
) {
|
|
defer wg.Done()
|
|
fedCtx := ctx
|
|
if timeout > 0 {
|
|
var cancel context.CancelFunc
|
|
fedCtx, cancel = context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
}
|
|
// for users who we do not have any knowledge about, try to start doing device list updates for them
|
|
// by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
|
|
// lack a stream ID.
|
|
userIDsForAllDevices := map[string]struct{}{}
|
|
for userID, deviceIDs := range devKeys {
|
|
if len(deviceIDs) == 0 {
|
|
userIDsForAllDevices[userID] = struct{}{}
|
|
delete(devKeys, userID)
|
|
}
|
|
}
|
|
// for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing
|
|
// a device list update, so we'll populate those back into the /keys/query list if not
|
|
for userID := range crossSigningKeys {
|
|
if devKeys == nil {
|
|
devKeys = map[string][]string{}
|
|
}
|
|
if _, ok := userIDsForAllDevices[userID]; !ok {
|
|
devKeys[userID] = []string{}
|
|
}
|
|
}
|
|
for userID := range userIDsForAllDevices {
|
|
err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
|
|
if err != nil {
|
|
logrus.WithFields(logrus.Fields{
|
|
logrus.ErrorKey: err,
|
|
"user_id": userID,
|
|
"server": serverName,
|
|
}).Error("Failed to manually update device lists for user")
|
|
// try to do it via /keys/query
|
|
devKeys[userID] = []string{}
|
|
continue
|
|
}
|
|
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
|
|
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
|
respMu.Lock()
|
|
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
|
|
respMu.Unlock()
|
|
if err != nil {
|
|
logrus.WithFields(logrus.Fields{
|
|
logrus.ErrorKey: err,
|
|
"user_id": userID,
|
|
"server": serverName,
|
|
}).Error("Failed to manually update device lists for user")
|
|
// try to do it via /keys/query
|
|
devKeys[userID] = []string{}
|
|
continue
|
|
}
|
|
}
|
|
if len(devKeys) == 0 {
|
|
return
|
|
}
|
|
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys)
|
|
if err == nil {
|
|
resultCh <- &queryKeysResp
|
|
return
|
|
}
|
|
respMu.Lock()
|
|
res.Failures[serverName] = map[string]interface{}{
|
|
"message": err.Error(),
|
|
}
|
|
|
|
// last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server
|
|
// is down, better to return something than nothing at all. Clients can know about the failure by
|
|
// inspecting the failures map though so they can know it's a cached response.
|
|
for userID, dkeys := range devKeys {
|
|
// drop the error as it's already a failure at this point
|
|
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, dkeys)
|
|
}
|
|
respMu.Unlock()
|
|
|
|
}
|
|
|
|
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
|
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
|
|
) error {
|
|
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
|
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
|
if err != nil {
|
|
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
|
|
}
|
|
if len(keys) < len(deviceIDs) {
|
|
return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
|
|
}
|
|
if len(deviceIDs) == 0 && len(keys) == 0 {
|
|
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
|
}
|
|
if res.DeviceKeys[userID] == nil {
|
|
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
|
}
|
|
|
|
for _, key := range keys {
|
|
if len(key.KeyJSON) == 0 {
|
|
continue // ignore deleted keys
|
|
}
|
|
// inject the display name
|
|
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
|
DisplayName string `json:"device_display_name,omitempty"`
|
|
}{key.DisplayName})
|
|
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
|
var keysToStore []api.DeviceMessage
|
|
// assert that the user ID / device ID are not lying for each key
|
|
for _, key := range req.DeviceKeys {
|
|
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
|
|
if err != nil {
|
|
continue // ignore invalid users
|
|
}
|
|
if serverName != a.ThisServer {
|
|
continue // ignore remote users
|
|
}
|
|
if len(key.KeyJSON) == 0 {
|
|
keysToStore = append(keysToStore, key.WithStreamID(0))
|
|
continue // deleted keys don't need sanity checking
|
|
}
|
|
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
|
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
|
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
|
keysToStore = append(keysToStore, key.WithStreamID(0))
|
|
continue
|
|
}
|
|
|
|
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
|
Err: fmt.Sprintf(
|
|
"user_id or device_id mismatch: users: %s - %s, devices: %s - %s",
|
|
gotUserID, key.UserID, gotDeviceID, key.DeviceID,
|
|
),
|
|
})
|
|
}
|
|
|
|
// get existing device keys so we can check for changes
|
|
existingKeys := make([]api.DeviceMessage, len(keysToStore))
|
|
for i := range keysToStore {
|
|
existingKeys[i] = api.DeviceMessage{
|
|
Type: api.TypeDeviceKeyUpdate,
|
|
DeviceKeys: &api.DeviceKeys{
|
|
UserID: keysToStore[i].UserID,
|
|
DeviceID: keysToStore[i].DeviceID,
|
|
},
|
|
}
|
|
}
|
|
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
|
res.Error = &api.KeyError{
|
|
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
|
|
}
|
|
return
|
|
}
|
|
if req.OnlyDisplayNameUpdates {
|
|
// add the display name field from keysToStore into existingKeys
|
|
keysToStore = appendDisplayNames(existingKeys, keysToStore)
|
|
}
|
|
// store the device keys and emit changes
|
|
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
|
if err != nil {
|
|
res.Error = &api.KeyError{
|
|
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
|
}
|
|
return
|
|
}
|
|
err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore)
|
|
if err != nil {
|
|
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
|
|
}
|
|
}
|
|
|
|
func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
|
if req.UserID == "" {
|
|
res.Error = &api.KeyError{
|
|
Err: "user ID missing",
|
|
}
|
|
}
|
|
if req.DeviceID != "" && len(req.OneTimeKeys) == 0 {
|
|
counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
|
if err != nil {
|
|
res.Error = &api.KeyError{
|
|
Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err),
|
|
}
|
|
}
|
|
if counts != nil {
|
|
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
|
}
|
|
return
|
|
}
|
|
for _, key := range req.OneTimeKeys {
|
|
// grab existing keys based on (user/device/algorithm/key ID)
|
|
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
|
i := 0
|
|
for keyIDWithAlgo := range key.KeyJSON {
|
|
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
|
i++
|
|
}
|
|
existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
|
|
if err != nil {
|
|
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
|
Err: "failed to query existing one-time keys: " + err.Error(),
|
|
})
|
|
continue
|
|
}
|
|
for keyIDWithAlgo := range existingKeys {
|
|
// if keys exist and the JSON doesn't match, error out as the key already exists
|
|
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
|
|
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
|
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
|
|
})
|
|
continue
|
|
}
|
|
}
|
|
// store one-time keys
|
|
counts, err := a.DB.StoreOneTimeKeys(ctx, key)
|
|
if err != nil {
|
|
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
|
Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
|
|
})
|
|
continue
|
|
}
|
|
// collect counts
|
|
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
|
}
|
|
|
|
}
|
|
|
|
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage) error {
|
|
// find keys in new that are not in existing
|
|
var keysAdded []api.DeviceMessage
|
|
for _, newKey := range new {
|
|
exists := false
|
|
for _, existingKey := range existing {
|
|
// Do not treat the absence of keys as equal, or else we will not emit key changes
|
|
// when users delete devices which never had a key to begin with as both KeyJSONs are nil.
|
|
if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) && len(existingKey.KeyJSON) > 0 {
|
|
exists = true
|
|
break
|
|
}
|
|
}
|
|
if !exists {
|
|
keysAdded = append(keysAdded, newKey)
|
|
}
|
|
}
|
|
return producer.ProduceKeyChanges(keysAdded)
|
|
}
|
|
|
|
func appendDisplayNames(existing, new []api.DeviceMessage) []api.DeviceMessage {
|
|
for i, existingDevice := range existing {
|
|
for _, newDevice := range new {
|
|
if existingDevice.DeviceID != newDevice.DeviceID {
|
|
continue
|
|
}
|
|
existingDevice.DisplayName = newDevice.DisplayName
|
|
existing[i] = existingDevice
|
|
}
|
|
}
|
|
return existing
|
|
}
|