diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index 6af597e1..6d37ce99 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -184,11 +184,11 @@ pub async fn register_route( None, &user_id, GlobalAccountDataEventType::PushRules.to_string().into(), - &ruma::events::push_rules::PushRulesEvent { + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { content: ruma::events::push_rules::PushRulesEventContent { global: push::Ruleset::server_default(&user_id), }, - }, + }).expect("to json always works"), )?; // Inhibit login does not work for guests diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index 316e284b..80cbb613 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -40,12 +40,12 @@ pub async fn create_content_route( services().media .create( mxc.clone(), - &body + body .filename .as_ref() .map(|filename| "inline; filename=".to_owned() + filename) .as_deref(), - &body.content_type.as_deref(), + body.content_type.as_deref(), &body.file, ) .await?; @@ -76,8 +76,8 @@ pub async fn get_remote_content( services().media .create( mxc.to_string(), - &content_response.content_disposition.as_deref(), - &content_response.content_type.as_deref(), + content_response.content_disposition.as_deref(), + content_response.content_type.as_deref(), &content_response.file, ) .await?; @@ -195,8 +195,8 @@ pub async fn get_content_thumbnail_route( services().media .upload_thumbnail( mxc, - &None, - &get_thumbnail_response.content_type.as_deref(), + None, + get_thumbnail_response.content_type.as_deref(), body.width.try_into().expect("all UInts are valid u32s"), body.height.try_into().expect("all UInts are valid u32s"), &get_thumbnail_response.file, diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 720c1e64..58ed0401 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -860,9 +860,8 @@ pub(crate) async fn invite_helper<'a>( "Could not accept incoming PDU as timeline event.", ))?; - let servers = services() - .rooms - .state_cache + // Bind to variable because of lifetimes + let servers = services().rooms.state_cache .room_servers(room_id) .filter_map(|r| r.ok()) .filter(|server| &**server != services().globals.server_name()); diff --git a/src/api/client_server/push.rs b/src/api/client_server/push.rs index 112fa002..12ec25dd 100644 --- a/src/api/client_server/push.rs +++ b/src/api/client_server/push.rs @@ -20,7 +20,7 @@ pub async fn get_pushrules_all_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: PushRulesEvent = services() + let event = services() .account_data .get( None, @@ -32,8 +32,12 @@ pub async fn get_pushrules_all_route( "PushRules event not found.", ))?; + let account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; + Ok(get_pushrules_all::v3::Response { - global: event.content.global, + global: account_data.global, }) } @@ -45,7 +49,7 @@ pub async fn get_pushrule_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: PushRulesEvent = services() + let event = services() .account_data .get( None, @@ -57,7 +61,11 @@ pub async fn get_pushrule_route( "PushRules event not found.", ))?; - let global = event.content.global; + let account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; + + let global = account_data.global; let rule = match body.kind { RuleKind::Override => global .override_ @@ -108,7 +116,7 @@ pub async fn set_pushrule_route( )); } - let mut event: PushRulesEvent = services() + let event = services() .account_data .get( None, @@ -120,7 +128,10 @@ pub async fn set_pushrule_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; + let mut account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + + let global = &mut account_data.content.global; match body.kind { RuleKind::Override => { global.override_.replace( @@ -187,7 +198,7 @@ pub async fn set_pushrule_route( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &event, + &serde_json::to_value(account_data).expect("to json value always works"), )?; Ok(set_pushrule::v3::Response {}) @@ -208,7 +219,7 @@ pub async fn get_pushrule_actions_route( )); } - let mut event: PushRulesEvent = services() + let event = services() .account_data .get( None, @@ -220,7 +231,11 @@ pub async fn get_pushrule_actions_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; + let account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; + + let global = account_data.global; let actions = match body.kind { RuleKind::Override => global .override_ @@ -265,7 +280,7 @@ pub async fn set_pushrule_actions_route( )); } - let mut event: PushRulesEvent = services() + let event = services() .account_data .get( None, @@ -277,7 +292,10 @@ pub async fn set_pushrule_actions_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; + let mut account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + + let global = &mut account_data.content.global; match body.kind { RuleKind::Override => { if let Some(mut rule) = global.override_.get(body.rule_id.as_str()).cloned() { @@ -316,7 +334,7 @@ pub async fn set_pushrule_actions_route( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &event, + &serde_json::to_value(account_data).expect("to json value always works"), )?; Ok(set_pushrule_actions::v3::Response {}) @@ -337,7 +355,7 @@ pub async fn get_pushrule_enabled_route( )); } - let mut event: PushRulesEvent = services() + let event = services() .account_data .get( None, @@ -349,7 +367,10 @@ pub async fn get_pushrule_enabled_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; + let account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + + let global = account_data.content.global; let enabled = match body.kind { RuleKind::Override => global .override_ @@ -397,7 +418,7 @@ pub async fn set_pushrule_enabled_route( )); } - let mut event: PushRulesEvent = services() + let event = services() .account_data .get( None, @@ -409,7 +430,10 @@ pub async fn set_pushrule_enabled_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; + let mut account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + + let global = &mut account_data.content.global; match body.kind { RuleKind::Override => { if let Some(mut rule) = global.override_.get(body.rule_id.as_str()).cloned() { @@ -453,7 +477,7 @@ pub async fn set_pushrule_enabled_route( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &event, + &serde_json::to_value(account_data).expect("to json value always works"), )?; Ok(set_pushrule_enabled::v3::Response {}) @@ -474,7 +498,7 @@ pub async fn delete_pushrule_route( )); } - let mut event: PushRulesEvent = services() + let event = services() .account_data .get( None, @@ -486,7 +510,10 @@ pub async fn delete_pushrule_route( "PushRules event not found.", ))?; - let global = &mut event.content.global; + let mut account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + + let global = &mut account_data.content.global; match body.kind { RuleKind::Override => { if let Some(rule) = global.override_.get(body.rule_id.as_str()).cloned() { @@ -520,7 +547,7 @@ pub async fn delete_pushrule_route( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &event, + &serde_json::to_value(account_data).expect("to json value always works"), )?; Ok(delete_pushrule::v3::Response {}) diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index eda57d57..c6d77c12 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -27,7 +27,7 @@ pub async fn set_read_marker_route( Some(&body.room_id), sender_user, RoomAccountDataEventType::FullyRead, - &fully_read_event, + &serde_json::to_value(fully_read_event).expect("to json value always works"), )?; if let Some(event) = &body.read_receipt { diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index 3489a9a9..9eb63831 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -175,7 +175,7 @@ async fn sync_helper( services().rooms.edus.presence.ping_presence(&sender_user)?; // Setup watchers, so if there's no response, we can wait for them - let watcher = services().globals.db.watch(&sender_user, &sender_device); + let watcher = services().globals.watch(&sender_user, &sender_device); let next_batch = services().globals.current_count()?; let next_batch_string = next_batch.to_string(); diff --git a/src/api/client_server/tag.rs b/src/api/client_server/tag.rs index bbea2d58..abf2b873 100644 --- a/src/api/client_server/tag.rs +++ b/src/api/client_server/tag.rs @@ -1,4 +1,4 @@ -use crate::{Result, Ruma, services}; +use crate::{Result, Ruma, services, Error}; use ruma::{ api::client::tag::{create_tag, delete_tag, get_tags}, events::{ @@ -18,18 +18,22 @@ pub async fn update_tag_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut tags_event = services() + let event = services() .account_data .get( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, - )? - .unwrap_or_else(|| TagEvent { + )?; + + let mut tags_event = event.map(|e| serde_json::from_str(e.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))) + .unwrap_or_else(|| Ok(TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, - }); + }))?; + tags_event .content .tags @@ -39,7 +43,7 @@ pub async fn update_tag_route( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, - &tags_event, + &serde_json::to_value(tags_event).expect("to json value always works"), )?; Ok(create_tag::v3::Response {}) @@ -55,25 +59,29 @@ pub async fn delete_tag_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut tags_event = services() + let mut event = services() .account_data .get( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, - )? - .unwrap_or_else(|| TagEvent { + )?; + + let mut tags_event = event.map(|e| serde_json::from_str(e.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))) + .unwrap_or_else(|| Ok(TagEvent { content: TagEventContent { tags: BTreeMap::new(), }, - }); + }))?; + tags_event.content.tags.remove(&body.tag.clone().into()); services().account_data.update( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, - &tags_event, + &serde_json::to_value(tags_event).expect("to json value always works"), )?; Ok(delete_tag::v3::Response {}) @@ -89,20 +97,23 @@ pub async fn get_tags_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let mut event = services() + .account_data + .get( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + )?; + + let mut tags_event = event.map(|e| serde_json::from_str(e.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))) + .unwrap_or_else(|| Ok(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }))?; + Ok(get_tags::v3::Response { - tags: services() - .account_data - .get( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - )? - .unwrap_or_else(|| TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }) - .content - .tags, + tags: tags_event.content.tags, }) } diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 45d749d0..647f4574 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1655,10 +1655,10 @@ pub async fn get_devices_route( .collect(), master_key: services() .users - .get_master_key(&body.user_id, |u| u.server_name() == sender_servername)?, + .get_master_key(&body.user_id, &|u| u.server_name() == sender_servername)?, self_signing_key: services() .users - .get_self_signing_key(&body.user_id, |u| u.server_name() == sender_servername)?, + .get_self_signing_key(&body.user_id, &|u| u.server_name() == sender_servername)?, }) } diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs index 49c9170f..f0325d2b 100644 --- a/src/database/key_value/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -1,19 +1,19 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}, events::{RoomAccountDataEventType, AnyEphemeralRoomEvent}, serde::Raw, RoomId}; use serde::{Serialize, de::DeserializeOwned}; use crate::{Result, database::KeyValueDatabase, service, Error, utils, services}; -impl service::account_data::Data for KeyValueDatabase { +impl service::account_data::Data for Arc { /// Places one event in the account data of the user and removes the previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - fn update( + fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - data: &T, + data: &serde_json::Value, ) -> Result<()> { let mut prefix = room_id .map(|r| r.to_string()) @@ -32,8 +32,7 @@ impl service::account_data::Data for KeyValueDatabase { let mut key = prefix; key.extend_from_slice(event_type.to_string().as_bytes()); - let json = serde_json::to_value(data).expect("all types here can be serialized"); // TODO: maybe add error handling - if json.get("type").is_none() || json.get("content").is_none() { + if data.get("type").is_none() || data.get("content").is_none() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Account data doesn't have all required fields.", @@ -42,7 +41,7 @@ impl service::account_data::Data for KeyValueDatabase { self.roomuserdataid_accountdata.insert( &roomuserdataid, - &serde_json::to_vec(&json).expect("to_vec always works on json values"), + &serde_json::to_vec(&data).expect("to_vec always works on json values"), )?; let prev = self.roomusertype_roomuserdataid.get(&key)?; @@ -60,12 +59,12 @@ impl service::account_data::Data for KeyValueDatabase { /// Searches the account data for a specific kind. #[tracing::instrument(skip(self, room_id, user_id, kind))] - fn get( + fn get( &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, - ) -> Result> { + ) -> Result>> { let mut key = room_id .map(|r| r.to_string()) .unwrap_or_default() diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index f427ba71..ee6ae206 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::appservice::Data for KeyValueDatabase { diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index e6652290..87119207 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -1,8 +1,136 @@ -use ruma::signatures::Ed25519KeyPair; +use std::{collections::BTreeMap, sync::Arc}; -use crate::{Result, service, database::KeyValueDatabase, Error, utils}; +use async_trait::async_trait; +use futures_util::{stream::FuturesUnordered, StreamExt}; +use ruma::{signatures::Ed25519KeyPair, UserId, DeviceId, ServerName, api::federation::discovery::{ServerSigningKeys, VerifyKey}, ServerSigningKeyId, MilliSecondsSinceUnixEpoch}; + +use crate::{Result, service, database::KeyValueDatabase, Error, utils, services}; + +pub const COUNTER: &[u8] = b"c"; + +#[async_trait] +impl service::globals::Data for Arc { + fn next_count(&self) -> Result { + utils::u64_from_bytes(&self.global.increment(COUNTER)?) + .map_err(|_| Error::bad_database("Count has invalid bytes.")) + } + + fn current_count(&self) -> Result { + self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Count has invalid bytes.")) + }) + } + + async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + let userid_bytes = user_id.as_bytes().to_vec(); + let mut userid_prefix = userid_bytes.clone(); + userid_prefix.push(0xff); + + let mut userdeviceid_prefix = userid_prefix.clone(); + userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); + userdeviceid_prefix.push(0xff); + + let mut futures = FuturesUnordered::new(); + + // Return when *any* user changed his key + // TODO: only send for user they share a room with + futures.push( + self.todeviceid_events + .watch_prefix(&userdeviceid_prefix), + ); + + futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); + futures.push( + self.userroomid_invitestate + .watch_prefix(&userid_prefix), + ); + futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); + futures.push( + self.userroomid_notificationcount + .watch_prefix(&userid_prefix), + ); + futures.push( + self.userroomid_highlightcount + .watch_prefix(&userid_prefix), + ); + + // Events for rooms we are in + for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(|r| r.ok()) { + let short_roomid = services() + .rooms + .short + .get_shortroomid(&room_id) + .ok() + .flatten() + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let roomid_bytes = room_id.as_bytes().to_vec(); + let mut roomid_prefix = roomid_bytes.clone(); + roomid_prefix.push(0xff); + + // PDUs + futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); + + // EDUs + futures.push( + self.roomid_lasttypingupdate + .watch_prefix(&roomid_bytes), + ); + + futures.push( + self.readreceiptid_readreceipt + .watch_prefix(&roomid_prefix), + ); + + // Key changes + futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); + + // Room account data + let mut roomuser_prefix = roomid_prefix.clone(); + roomuser_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.roomusertype_roomuserdataid + .watch_prefix(&roomuser_prefix), + ); + } + + let mut globaluserdata_prefix = vec![0xff]; + globaluserdata_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.roomusertype_roomuserdataid + .watch_prefix(&globaluserdata_prefix), + ); + + // More key changes (used when user is not joined to any rooms) + futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); + + // One time keys + futures.push( + self.userid_lastonetimekeyupdate + .watch_prefix(&userid_bytes), + ); + + futures.push(Box::pin(services().globals.rotate.watch())); + + // Wait until one of them finds something + futures.next().await; + + Ok(()) + } + + fn cleanup(&self) -> Result<()> { + self._db.cleanup() + } + + fn memory_usage(&self) -> Result { + self._db.memory_usage() + } -impl service::globals::Data for KeyValueDatabase { fn load_keypair(&self) -> Result { let keypair_bytes = self.global.get(b"keypair")?.map_or_else( || { @@ -39,4 +167,81 @@ impl service::globals::Data for KeyValueDatabase { fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } + + fn add_signing_key( + &self, + origin: &ServerName, + new_keys: ServerSigningKeys, + ) -> Result, VerifyKey>> { + // Not atomic, but this is not critical + let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; + + let mut keys = signingkeys + .and_then(|keys| serde_json::from_slice(&keys).ok()) + .unwrap_or_else(|| { + // Just insert "now", it doesn't matter + ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) + }); + + let ServerSigningKeys { + verify_keys, + old_verify_keys, + .. + } = new_keys; + + keys.verify_keys.extend(verify_keys.into_iter()); + keys.old_verify_keys.extend(old_verify_keys.into_iter()); + + self.server_signingkeys.insert( + origin.as_bytes(), + &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), + )?; + + let mut tree = keys.verify_keys; + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); + + Ok(tree) + } + + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. + fn signing_keys_for( + &self, + origin: &ServerName, + ) -> Result, VerifyKey>> { + let signingkeys = self + .server_signingkeys + .get(origin.as_bytes())? + .and_then(|bytes| serde_json::from_slice(&bytes).ok()) + .map(|keys: ServerSigningKeys| { + let mut tree = keys.verify_keys; + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); + tree + }) + .unwrap_or_else(BTreeMap::new); + + Ok(signingkeys) + } + + fn database_version(&self) -> Result { + self.global.get(b"version")?.map_or(Ok(0), |version| { + utils::u64_from_bytes(&version) + .map_err(|_| Error::bad_database("Database version id is invalid.")) + }) + } + + fn bump_database_version(&self, new_version: u64) -> Result<()> { + self.global + .insert(b"version", &new_version.to_be_bytes())?; + Ok(()) + } + + } diff --git a/src/database/key_value/key_backups.rs b/src/database/key_value/key_backups.rs index 8171451c..c59ed36b 100644 --- a/src/database/key_value/key_backups.rs +++ b/src/database/key_value/key_backups.rs @@ -1,10 +1,10 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; use ruma::{UserId, serde::Raw, api::client::{backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, error::ErrorKind}, RoomId}; use crate::{Result, service, database::KeyValueDatabase, services, Error, utils}; -impl service::key_backups::Data for KeyValueDatabase { +impl service::key_backups::Data for Arc { fn create_backup( &self, user_id: &UserId, diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index a84cbd53..1726755a 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -1,9 +1,11 @@ +use std::sync::Arc; + use ruma::api::client::error::ErrorKind; use crate::{database::KeyValueDatabase, service, Error, utils, Result}; -impl service::media::Data for KeyValueDatabase { - fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: &Option<&str>, content_type: &Option<&str>) -> Result> { +impl service::media::Data for Arc { + fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>) -> Result> { let mut key = mxc.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(&width.to_be_bytes()); diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index b05e47be..85d1d864 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, api::client::push::{set_pusher, get_pushers}}; use crate::{service, database::KeyValueDatabase, Error, Result}; -impl service::pusher::Data for KeyValueDatabase { +impl service::pusher::Data for Arc { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { let mut key = sender.as_bytes().to_vec(); key.push(0xff); diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index 0aa8dd48..437902df 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{RoomId, RoomAliasId, api::client::error::ErrorKind}; use crate::{service, database::KeyValueDatabase, utils, Error, services, Result}; -impl service::rooms::alias::Data for KeyValueDatabase { +impl service::rooms::alias::Data for Arc { fn set_alias( &self, alias: &RoomAliasId, diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs index 888d472d..2dffb04b 100644 --- a/src/database/key_value/rooms/auth_chain.rs +++ b/src/database/key_value/rooms/auth_chain.rs @@ -1,28 +1,60 @@ -use std::{collections::HashSet, mem::size_of}; +use std::{collections::HashSet, mem::size_of, sync::Arc}; use crate::{service, database::KeyValueDatabase, Result, utils}; -impl service::rooms::auth_chain::Data for KeyValueDatabase { - fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result>> { - Ok(self.shorteventid_authchain - .get(&shorteventid.to_be_bytes())? - .map(|chain| { - chain - .chunks_exact(size_of::()) - .map(|chunk| { - utils::u64_from_bytes(chunk).expect("byte length is correct") - }) - .collect() - })) +impl service::rooms::auth_chain::Data for Arc { + fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>>> { + // Check RAM cache + if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { + return Ok(Some(Arc::clone(result))); + } + + // We only save auth chains for single events in the db + if key.len() == 1 { + // Check DB cache + let chain = self.shorteventid_authchain + .get(&key[0].to_be_bytes())? + .map(|chain| { + chain + .chunks_exact(size_of::()) + .map(|chunk| { + utils::u64_from_bytes(chunk).expect("byte length is correct") + }) + .collect() + }); + + if let Some(chain) = chain { + let chain = Arc::new(chain); + + // Cache in RAM + self.auth_chain_cache + .lock() + .unwrap() + .insert(vec![key[0]], Arc::clone(&chain)); + + return Ok(Some(chain)); + } + } + + Ok(None) + } - fn cache_eventid_authchain(&self, shorteventid: u64, auth_chain: &HashSet) -> Result<()> { - self.shorteventid_authchain.insert( - &shorteventid.to_be_bytes(), - &auth_chain - .iter() - .flat_map(|s| s.to_be_bytes().to_vec()) - .collect::>(), - ) + fn cache_auth_chain(&self, key: Vec, auth_chain: Arc>) -> Result<()> { + // Only persist single events in db + if key.len() == 1 { + self.shorteventid_authchain.insert( + &key[0].to_be_bytes(), + &auth_chain + .iter() + .flat_map(|s| s.to_be_bytes().to_vec()) + .collect::>(), + )?; + } + + // Cache in RAM + self.auth_chain_cache.lock().unwrap().insert(key, auth_chain); + + Ok(()) } } diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs index 727004e7..864e75e9 100644 --- a/src/database/key_value/rooms/directory.rs +++ b/src/database/key_value/rooms/directory.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::RoomId; use crate::{service, database::KeyValueDatabase, utils, Error, Result}; -impl service::rooms::directory::Data for KeyValueDatabase { +impl service::rooms::directory::Data for Arc { fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) } diff --git a/src/database/key_value/rooms/edus/mod.rs b/src/database/key_value/rooms/edus/mod.rs index b5007f89..03e4219e 100644 --- a/src/database/key_value/rooms/edus/mod.rs +++ b/src/database/key_value/rooms/edus/mod.rs @@ -2,6 +2,8 @@ mod presence; mod typing; mod read_receipt; +use std::sync::Arc; + use crate::{service, database::KeyValueDatabase}; -impl service::rooms::edus::Data for KeyValueDatabase {} +impl service::rooms::edus::Data for Arc {} diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs index 1477c28b..5aeb1477 100644 --- a/src/database/key_value/rooms/edus/presence.rs +++ b/src/database/key_value/rooms/edus/presence.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use ruma::{UserId, RoomId, events::presence::PresenceEvent, presence::PresenceState, UInt}; use crate::{service, database::KeyValueDatabase, utils, Error, services, Result}; -impl service::rooms::edus::presence::Data for KeyValueDatabase { +impl service::rooms::edus::presence::Data for Arc { fn update_presence( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index a12e2653..7fcb8ac8 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -1,10 +1,10 @@ -use std::mem; +use std::{mem, sync::Arc}; use ruma::{UserId, RoomId, events::receipt::ReceiptEvent, serde::Raw, signatures::CanonicalJsonObject}; use crate::{database::KeyValueDatabase, service, utils, Error, services, Result}; -impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { +impl service::rooms::edus::read_receipt::Data for Arc { fn readreceipt_update( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/edus/typing.rs b/src/database/key_value/rooms/edus/typing.rs index b7d35968..7f3526d9 100644 --- a/src/database/key_value/rooms/edus/typing.rs +++ b/src/database/key_value/rooms/edus/typing.rs @@ -1,10 +1,10 @@ -use std::collections::HashSet; +use std::{collections::HashSet, sync::Arc}; use ruma::{UserId, RoomId}; use crate::{database::KeyValueDatabase, service, utils, Error, services, Result}; -impl service::rooms::edus::typing::Data for KeyValueDatabase { +impl service::rooms::edus::typing::Data for Arc { fn typing_add( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs index 133e1d04..b16657aa 100644 --- a/src/database/key_value/rooms/lazy_load.rs +++ b/src/database/key_value/rooms/lazy_load.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, DeviceId, RoomId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::rooms::lazy_loading::Data for KeyValueDatabase { +impl service::rooms::lazy_loading::Data for Arc { fn lazy_load_was_sent_before( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index db2bc69b..560beb90 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::RoomId; use crate::{service, database::KeyValueDatabase, Result, services}; -impl service::rooms::metadata::Data for KeyValueDatabase { +impl service::rooms::metadata::Data for Arc { fn exists(&self, room_id: &RoomId) -> Result { let prefix = match services().rooms.short.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), diff --git a/src/database/key_value/rooms/mod.rs b/src/database/key_value/rooms/mod.rs index 406943ed..97c29e5b 100644 --- a/src/database/key_value/rooms/mod.rs +++ b/src/database/key_value/rooms/mod.rs @@ -15,6 +15,8 @@ mod state_compressor; mod timeline; mod user; +use std::sync::Arc; + use crate::{database::KeyValueDatabase, service}; -impl service::rooms::Data for KeyValueDatabase {} +impl service::rooms::Data for Arc {} diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs index aa975449..b1ae816a 100644 --- a/src/database/key_value/rooms/outlier.rs +++ b/src/database/key_value/rooms/outlier.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{EventId, signatures::CanonicalJsonObject}; use crate::{service, database::KeyValueDatabase, PduEvent, Error, Result}; -impl service::rooms::outlier::Data for KeyValueDatabase { +impl service::rooms::outlier::Data for Arc { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs index f3ac414f..f5e8f766 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -4,7 +4,7 @@ use ruma::{RoomId, EventId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::rooms::pdu_metadata::Data for KeyValueDatabase { +impl service::rooms::pdu_metadata::Data for Arc { fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { for prev in event_ids { let mut key = room_id.as_bytes().to_vec(); diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index dfbdbc64..7b8d2783 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -1,10 +1,10 @@ -use std::mem::size_of; +use std::{mem::size_of, sync::Arc}; use ruma::RoomId; use crate::{service, database::KeyValueDatabase, utils, Result, services}; -impl service::rooms::search::Data for KeyValueDatabase { +impl service::rooms::search::Data for Arc { fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> { let mut batch = message_body .split_terminator(|c: char| !c.is_alphanumeric()) diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs index 91296385..9a302b56 100644 --- a/src/database/key_value/rooms/short.rs +++ b/src/database/key_value/rooms/short.rs @@ -1,4 +1,6 @@ +use std::sync::Arc; + use crate::{database::KeyValueDatabase, service}; -impl service::rooms::short::Data for KeyValueDatabase { +impl service::rooms::short::Data for Arc { } diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs index 405939dd..527c2403 100644 --- a/src/database/key_value/rooms/state.rs +++ b/src/database/key_value/rooms/state.rs @@ -1,11 +1,12 @@ use ruma::{RoomId, EventId}; +use tokio::sync::MutexGuard; use std::sync::Arc; -use std::{sync::MutexGuard, collections::HashSet}; +use std::collections::HashSet; use std::fmt::Debug; use crate::{service, database::KeyValueDatabase, utils, Error, Result}; -impl service::rooms::state::Data for KeyValueDatabase { +impl service::rooms::state::Data for Arc { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { self.roomid_shortstatehash .get(room_id.as_bytes())? @@ -48,7 +49,7 @@ impl service::rooms::state::Data for KeyValueDatabase { fn set_forward_extremities<'a>( &self, room_id: &RoomId, - event_ids: impl IntoIterator + Debug, + event_ids: &mut dyn Iterator, _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs index 4d5bd4a1..9af45db3 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/database/key_value/rooms/state_accessor.rs @@ -5,7 +5,7 @@ use async_trait::async_trait; use ruma::{EventId, events::StateEventType, RoomId}; #[async_trait] -impl service::rooms::state_accessor::Data for KeyValueDatabase { +impl service::rooms::state_accessor::Data for Arc { async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { let full_state = services().rooms.state_compressor .load_shortstatehash_info(shortstatehash)? diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 5f054858..bdb8cf81 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, RoomId, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw}; use crate::{service, database::KeyValueDatabase, services, Result}; -impl service::rooms::state_cache::Data for KeyValueDatabase { +impl service::rooms::state_cache::Data for Arc { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs index aee1890c..e1c0280b 100644 --- a/src/database/key_value/rooms/state_compressor.rs +++ b/src/database/key_value/rooms/state_compressor.rs @@ -1,8 +1,8 @@ -use std::{collections::HashSet, mem::size_of}; +use std::{collections::HashSet, mem::size_of, sync::Arc}; use crate::{service::{self, rooms::state_compressor::data::StateDiff}, database::KeyValueDatabase, Error, utils, Result}; -impl service::rooms::state_compressor::Data for KeyValueDatabase { +impl service::rooms::state_compressor::Data for Arc { fn get_statediff(&self, shortstatehash: u64) -> Result { let value = self .shortstatehash_statediff diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index a3b6c17d..2d334b96 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -5,7 +5,7 @@ use tracing::error; use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result, services}; -impl service::rooms::timeline::Data for KeyValueDatabase { +impl service::rooms::timeline::Data for Arc { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self .lasttimelinecount_cache diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 66681e3c..4d20b00a 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, RoomId}; use crate::{service, database::KeyValueDatabase, utils, Error, Result, services}; -impl service::rooms::user::Data for KeyValueDatabase { +impl service::rooms::user::Data for Arc { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); diff --git a/src/database/key_value/transaction_ids.rs b/src/database/key_value/transaction_ids.rs index a63b3c5d..7fa69081 100644 --- a/src/database/key_value/transaction_ids.rs +++ b/src/database/key_value/transaction_ids.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, DeviceId, TransactionId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::transaction_ids::Data for KeyValueDatabase { +impl service::transaction_ids::Data for Arc { fn add_txnid( &self, user_id: &UserId, diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs index cf242dec..8752e55a 100644 --- a/src/database/key_value/uiaa.rs +++ b/src/database/key_value/uiaa.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}}; use crate::{database::KeyValueDatabase, service, Error, Result}; -impl service::uiaa::Data for KeyValueDatabase { +impl service::uiaa::Data for Arc { fn set_uiaa_request( &self, user_id: &UserId, diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 338d8800..1ac85b36 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -1,11 +1,11 @@ -use std::{mem::size_of, collections::BTreeMap}; +use std::{mem::size_of, collections::BTreeMap, sync::Arc}; use ruma::{api::client::{filter::IncomingFilterDefinition, error::ErrorKind, device::Device}, UserId, RoomAliasId, MxcUri, DeviceId, MilliSecondsSinceUnixEpoch, DeviceKeyId, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, serde::Raw, events::{AnyToDeviceEvent, StateEventType}, DeviceKeyAlgorithm, UInt}; use tracing::warn; use crate::{service::{self, users::clean_signatures}, database::KeyValueDatabase, Error, utils, services, Result}; -impl service::users::Data for KeyValueDatabase { +impl service::users::Data for Arc { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) @@ -687,10 +687,10 @@ impl service::users::Data for KeyValueDatabase { }) } - fn get_master_key bool>( + fn get_master_key( &self, user_id: &UserId, - allowed_signatures: F, + allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { self.userid_masterkeyid .get(user_id.as_bytes())? @@ -708,10 +708,10 @@ impl service::users::Data for KeyValueDatabase { }) } - fn get_self_signing_key bool>( + fn get_self_signing_key( &self, user_id: &UserId, - allowed_signatures: F, + allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { self.userid_selfsigningkeyid .get(user_id.as_bytes())? diff --git a/src/database/mod.rs b/src/database/mod.rs index aa5c5839..35922f0b 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -402,10 +402,10 @@ impl KeyValueDatabase { }); - let services_raw = Services::build(Arc::clone(&db)); + let services_raw = Box::new(Services::build(Arc::clone(&db))); // This is the first and only time we initialize the SERVICE static - *SERVICES.write().unwrap() = Some(services_raw); + *SERVICES.write().unwrap() = Some(Box::leak(services_raw)); // Matrix resource ownership is based on the server name; changing it @@ -877,105 +877,6 @@ impl KeyValueDatabase { services().globals.rotate.fire(); } - pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) { - let userid_bytes = user_id.as_bytes().to_vec(); - let mut userid_prefix = userid_bytes.clone(); - userid_prefix.push(0xff); - - let mut userdeviceid_prefix = userid_prefix.clone(); - userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); - userdeviceid_prefix.push(0xff); - - let mut futures = FuturesUnordered::new(); - - // Return when *any* user changed his key - // TODO: only send for user they share a room with - futures.push( - self.todeviceid_events - .watch_prefix(&userdeviceid_prefix), - ); - - futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); - futures.push( - self.userroomid_invitestate - .watch_prefix(&userid_prefix), - ); - futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); - futures.push( - self.userroomid_notificationcount - .watch_prefix(&userid_prefix), - ); - futures.push( - self.userroomid_highlightcount - .watch_prefix(&userid_prefix), - ); - - // Events for rooms we are in - for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(|r| r.ok()) { - let short_roomid = services() - .rooms - .short - .get_shortroomid(&room_id) - .ok() - .flatten() - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let roomid_bytes = room_id.as_bytes().to_vec(); - let mut roomid_prefix = roomid_bytes.clone(); - roomid_prefix.push(0xff); - - // PDUs - futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); - - // EDUs - futures.push( - self.roomid_lasttypingupdate - .watch_prefix(&roomid_bytes), - ); - - futures.push( - self.readreceiptid_readreceipt - .watch_prefix(&roomid_prefix), - ); - - // Key changes - futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); - - // Room account data - let mut roomuser_prefix = roomid_prefix.clone(); - roomuser_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&roomuser_prefix), - ); - } - - let mut globaluserdata_prefix = vec![0xff]; - globaluserdata_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&globaluserdata_prefix), - ); - - // More key changes (used when user is not joined to any rooms) - futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); - - // One time keys - futures.push( - self.userid_lastonetimekeyupdate - .watch_prefix(&userid_bytes), - ); - - futures.push(Box::pin(services().globals.rotate.watch())); - - // Wait until one of them finds something - futures.next().await; - } - #[tracing::instrument(skip(self))] pub fn flush(&self) -> Result<()> { let start = std::time::Instant::now(); @@ -1021,7 +922,7 @@ impl KeyValueDatabase { } let start = Instant::now(); - if let Err(e) = services().globals.db._db.cleanup() { + if let Err(e) = services().globals.cleanup() { error!("cleanup: Errored: {}", e); } else { info!("cleanup: Finished in {:?}", start.elapsed()); @@ -1048,9 +949,9 @@ fn set_emergency_access() -> Result { None, &conduit_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &GlobalAccountDataEvent { + &serde_json::to_value(&GlobalAccountDataEvent { content: PushRulesEventContent { global: ruleset }, - }, + }).expect("to json value always works"), )?; res diff --git a/src/lib.rs b/src/lib.rs index 75cf6c7e..c103d529 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,9 +20,9 @@ pub use utils::error::{Error, Result}; pub use service::{Services, pdu::PduEvent}; pub use api::ruma_wrapper::{Ruma, RumaResponse}; -pub static SERVICES: RwLock>> = RwLock::new(None); +pub static SERVICES: RwLock> = RwLock::new(None); -pub fn services<'a>() -> Arc { - Arc::clone(&SERVICES.read().unwrap()) +pub fn services<'a>() -> &'static Services { + &SERVICES.read().unwrap().expect("SERVICES should be initialized when this is called") } diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs index 0f8e0bf5..65780a69 100644 --- a/src/service/account_data/data.rs +++ b/src/service/account_data/data.rs @@ -1,26 +1,25 @@ use std::collections::HashMap; use ruma::{UserId, RoomId, events::{RoomAccountDataEventType, AnyEphemeralRoomEvent}, serde::Raw}; -use serde::{Serialize, de::DeserializeOwned}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { /// Places one event in the account data of the user and removes the previous entry. - fn update( + fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - data: &T, + data: &serde_json::Value, ) -> Result<()>; /// Searches the account data for a specific kind. - fn get( + fn get( &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, - ) -> Result>; + ) -> Result>>; /// Returns all changes to the account data that happened after `since`. fn changes_since( diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 35ca1495..9785478b 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -24,24 +24,24 @@ pub struct Service { impl Service { /// Places one event in the account data of the user and removes the previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - pub fn update( + pub fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - data: &T, + data: &serde_json::Value, ) -> Result<()> { self.db.update(room_id, user_id, event_type, data) } /// Searches the account data for a specific kind. #[tracing::instrument(skip(self, room_id, user_id, event_type))] - pub fn get( + pub fn get( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - ) -> Result> { + ) -> Result>> { self.db.get(room_id, user_id, event_type) } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 48f828fc..32a709c1 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -28,7 +28,7 @@ use ruma::{ use serde_json::value::to_raw_value; use tokio::sync::{mpsc, MutexGuard, RwLock, RwLockReadGuard}; -use crate::{Result, services, Error, api::{server_server, client_server::AUTO_GEN_PASSWORD_LENGTH}, PduEvent, utils::{HtmlEscape, self}}; +use crate::{Result, services, Error, api::{server_server, client_server::{AUTO_GEN_PASSWORD_LENGTH, leave_all_rooms}}, PduEvent, utils::{HtmlEscape, self}}; use super::pdu::PduBuilder; @@ -179,7 +179,8 @@ impl Service { let conduit_room = services() .rooms - .id_from_alias( + .alias + .resolve_local_alias( format!("#admins:{}", services().globals.server_name()) .as_str() .try_into() @@ -221,7 +222,7 @@ impl Service { .roomid_mutex_state .write() .unwrap() - .entry(conduit_room.clone()) + .entry(conduit_room.to_owned()) .or_default(), ); @@ -599,11 +600,11 @@ impl Service { ruma::events::GlobalAccountDataEventType::PushRules .to_string() .into(), - &ruma::events::push_rules::PushRulesEvent { + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { content: ruma::events::push_rules::PushRulesEventContent { global: ruma::push::Ruleset::server_default(&user_id), }, - }, + }).expect("to json value always works"), )?; // we dont add a device since we're not the user, just the creator @@ -614,12 +615,14 @@ impl Service { )) } AdminCommand::DisableRoom { room_id } => { - services().rooms.disabledroomids.insert(room_id.as_bytes(), &[])?; - RoomMessageEventContent::text_plain("Room disabled.") + todo!(); + //services().rooms.disabledroomids.insert(room_id.as_bytes(), &[])?; + //RoomMessageEventContent::text_plain("Room disabled.") } AdminCommand::EnableRoom { room_id } => { - services().rooms.disabledroomids.remove(room_id.as_bytes())?; - RoomMessageEventContent::text_plain("Room enabled.") + todo!(); + //services().rooms.disabledroomids.remove(room_id.as_bytes())?; + //RoomMessageEventContent::text_plain("Room enabled.") } AdminCommand::DeactivateUser { leave_rooms, @@ -635,7 +638,7 @@ impl Service { services().users.deactivate_account(&user_id)?; if leave_rooms { - services().rooms.leave_all_rooms(&user_id).await?; + leave_all_rooms(&user_id).await?; } RoomMessageEventContent::text_plain(format!( @@ -694,7 +697,7 @@ impl Service { if leave_rooms { for &user_id in &user_ids { - let _ = services().rooms.leave_all_rooms(user_id).await; + let _ = leave_all_rooms(user_id).await; } } @@ -804,7 +807,7 @@ impl Service { pub(crate) async fn create_admin_room(&self) -> Result<()> { let room_id = RoomId::new(services().globals.server_name()); - services().rooms.get_or_create_shortroomid(&room_id)?; + services().rooms.short.get_or_create_shortroomid(&room_id)?; let mutex_state = Arc::clone( services().globals diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index a70bf9c1..744f0f94 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -1,6 +1,6 @@ use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { /// Registers an appservice and returns the ID to the caller fn register_appservice(&self, yaml: serde_yaml::Value) -> Result; diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 1a5ce50c..ad5ab4aa 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,10 +1,12 @@ mod data; +use std::sync::Arc; + pub use data::Data; use crate::Result; pub struct Service { - db: Box, + db: Arc, } impl Service { diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index f36ab61b..0f74b2a7 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -1,8 +1,30 @@ -use ruma::signatures::Ed25519KeyPair; +use std::collections::BTreeMap; + +use async_trait::async_trait; +use ruma::{signatures::Ed25519KeyPair, DeviceId, UserId, ServerName, api::federation::discovery::{ServerSigningKeys, VerifyKey}, ServerSigningKeyId}; use crate::Result; -pub trait Data { +#[async_trait] +pub trait Data: Send + Sync { + fn next_count(&self) -> Result; + fn current_count(&self) -> Result; + async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; + fn cleanup(&self) -> Result<()>; + fn memory_usage(&self) -> Result; fn load_keypair(&self) -> Result; fn remove_keypair(&self) -> Result<()>; + fn add_signing_key( + &self, + origin: &ServerName, + new_keys: ServerSigningKeys, + ) -> Result, VerifyKey>>; + + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. + fn signing_keys_for( + &self, + origin: &ServerName, + ) -> Result, VerifyKey>>; + fn database_version(&self) -> Result; + fn bump_database_version(&self, new_version: u64) -> Result<()>; } diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 48d7b064..8fd69dfe 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -26,8 +26,6 @@ use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; use tracing::error; use trust_dns_resolver::TokioAsyncResolver; -pub const COUNTER: &[u8] = b"c"; - type WellKnownMap = HashMap, (FedDest, String)>; type TlsNameMap = HashMap, u16)>; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries @@ -198,16 +196,24 @@ impl Service { #[tracing::instrument(skip(self))] pub fn next_count(&self) -> Result { - utils::u64_from_bytes(&self.globals.increment(COUNTER)?) - .map_err(|_| Error::bad_database("Count has invalid bytes.")) + self.db.next_count() } #[tracing::instrument(skip(self))] pub fn current_count(&self) -> Result { - self.globals.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count has invalid bytes.")) - }) + self.db.current_count() + } + + pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + self.db.watch(user_id, device_id).await + } + + pub fn cleanup(&self) -> Result<()> { + self.db.cleanup() + } + + pub fn memory_usage(&self) -> Result { + self.db.memory_usage() } pub fn server_name(&self) -> &ServerName { @@ -296,38 +302,7 @@ impl Service { origin: &ServerName, new_keys: ServerSigningKeys, ) -> Result, VerifyKey>> { - // Not atomic, but this is not critical - let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; - - let mut keys = signingkeys - .and_then(|keys| serde_json::from_slice(&keys).ok()) - .unwrap_or_else(|| { - // Just insert "now", it doesn't matter - ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) - }); - - let ServerSigningKeys { - verify_keys, - old_verify_keys, - .. - } = new_keys; - - keys.verify_keys.extend(verify_keys.into_iter()); - keys.old_verify_keys.extend(old_verify_keys.into_iter()); - - self.server_signingkeys.insert( - origin.as_bytes(), - &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), - )?; - - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); - - Ok(tree) + self.db.add_signing_key(origin, new_keys) } /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. @@ -335,35 +310,15 @@ impl Service { &self, origin: &ServerName, ) -> Result, VerifyKey>> { - let signingkeys = self - .server_signingkeys - .get(origin.as_bytes())? - .and_then(|bytes| serde_json::from_slice(&bytes).ok()) - .map(|keys: ServerSigningKeys| { - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); - tree - }) - .unwrap_or_else(BTreeMap::new); - - Ok(signingkeys) + self.db.signing_keys_for(origin) } pub fn database_version(&self) -> Result { - self.globals.get(b"version")?.map_or(Ok(0), |version| { - utils::u64_from_bytes(&version) - .map_err(|_| Error::bad_database("Database version id is invalid.")) - }) + self.db.database_version() } pub fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.globals - .insert(b"version", &new_version.to_be_bytes())?; - Ok(()) + self.db.bump_database_version(new_version) } pub fn get_media_folder(&self) -> PathBuf { diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs index 6f6359eb..226b1e16 100644 --- a/src/service/key_backups/data.rs +++ b/src/service/key_backups/data.rs @@ -3,7 +3,7 @@ use std::collections::BTreeMap; use ruma::{api::client::backup::{BackupAlgorithm, RoomKeyBackup, KeyBackupData}, serde::Raw, UserId, RoomId}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { fn create_backup( &self, user_id: &UserId, diff --git a/src/service/media/data.rs b/src/service/media/data.rs index 94975de7..2e24049a 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,7 +1,7 @@ use crate::Result; -pub trait Data { - fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: &Option<&str>, content_type: &Option<&str>) -> Result>; +pub trait Data: Send + Sync { + fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>) -> Result>; /// Returns content_disposition, content_type and the metadata key. fn search_file_metadata(&self, mxc: String, width: u32, height: u32) -> Result<(Option, Option, Vec)>; diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index d61292bb..f86251fa 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -24,8 +24,8 @@ impl Service { pub async fn create( &self, mxc: String, - content_disposition: &Option<&str>, - content_type: &Option<&str>, + content_disposition: Option<&str>, + content_type: Option<&str>, file: &[u8], ) -> Result<()> { // Width, Height = 0 if it's not a thumbnail @@ -42,8 +42,8 @@ impl Service { pub async fn upload_thumbnail( &self, mxc: String, - content_disposition: &Option<&str>, - content_type: &Option<&str>, + content_disposition: Option<&str>, + content_type: Option<&str>, width: u32, height: u32, file: &[u8], @@ -108,7 +108,7 @@ impl Service { .thumbnail_properties(width, height) .unwrap_or((0, 0, false)); // 0, 0 because that's the original file - if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, width, height) { + if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), width, height) { // Using saved thumbnail let path = services().globals.get_media_file(&key); let mut file = Vec::new(); @@ -119,7 +119,7 @@ impl Service { content_type, file: file.to_vec(), })) - } else if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) { + } else if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), 0, 0) { // Generate a thumbnail let path = services().globals.get_media_file(&key); let mut file = Vec::new(); @@ -180,7 +180,7 @@ impl Service { thumbnail.write_to(&mut thumbnail_bytes, image::ImageOutputFormat::Png)?; // Save thumbnail in database so we don't have to generate it again next time - let thumbnail_key = self.db.create_file_metadata(mxc, width, height, content_disposition, content_type)?; + let thumbnail_key = self.db.create_file_metadata(mxc, width, height, content_disposition.as_deref(), content_type.as_deref())?; let path = services().globals.get_media_file(&thumbnail_key); let mut f = File::create(path).await?; diff --git a/src/service/mod.rs b/src/service/mod.rs index 47d4651d..a1a728c5 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -30,20 +30,20 @@ pub struct Services { } impl Services { - pub fn build(db: Arc) { + pub fn build(db: Arc) -> Self { Self { - appservice: appservice::Service { db: Arc::clone(&db) }, - pusher: appservice::Service { db: Arc::clone(&db) }, - rooms: appservice::Service { db: Arc::clone(&db) }, - transaction_ids: appservice::Service { db: Arc::clone(&db) }, - uiaa: appservice::Service { db: Arc::clone(&db) }, - users: appservice::Service { db: Arc::clone(&db) }, - account_data: appservice::Service { db: Arc::clone(&db) }, - admin: appservice::Service { db: Arc::clone(&db) }, - globals: appservice::Service { db: Arc::clone(&db) }, - key_backups: appservice::Service { db: Arc::clone(&db) }, - media: appservice::Service { db: Arc::clone(&db) }, - sending: appservice::Service { db: Arc::clone(&db) }, + appservice: appservice::Service { db: db.clone() }, + pusher: pusher::Service { db: db.clone() }, + rooms: rooms::Service { db: Arc::clone(&db) }, + transaction_ids: transaction_ids::Service { db: Arc::clone(&db) }, + uiaa: uiaa::Service { db: Arc::clone(&db) }, + users: users::Service { db: Arc::clone(&db) }, + account_data: account_data::Service { db: Arc::clone(&db) }, + admin: admin::Service { db: Arc::clone(&db) }, + globals: globals::Service { db: Arc::clone(&db) }, + key_backups: key_backups::Service { db: Arc::clone(&db) }, + media: media::Service { db: Arc::clone(&db) }, + sending: sending::Service { db: Arc::clone(&db) }, } } } diff --git a/src/service/pdu.rs b/src/service/pdu.rs index 2ed79f2c..3be3300c 100644 --- a/src/service/pdu.rs +++ b/src/service/pdu.rs @@ -343,7 +343,7 @@ pub(crate) fn gen_event_id_canonical_json( .and_then(|id| RoomId::parse(id.as_str()?).ok()) .ok_or_else(|| Error::bad_database("PDU in db has invalid room_id."))?; - let room_version_id = services().rooms.get_room_version(&room_id); + let room_version_id = services().rooms.state.get_room_version(&room_id); let event_id = format!( "${}", diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs index 3951da79..305a5383 100644 --- a/src/service/pusher/data.rs +++ b/src/service/pusher/data.rs @@ -1,7 +1,7 @@ use ruma::{UserId, api::client::push::{set_pusher, get_pushers}}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()>; fn get_pusher(&self, senderkey: &[u8]) -> Result>; diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index af30ca47..e65c57ab 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -3,6 +3,7 @@ pub use data::Data; use crate::{services, Error, PduEvent, Result}; use bytes::BytesMut; +use ruma::api::IncomingResponse; use ruma::{ api::{ client::push::{get_pushers, set_pusher, PusherKind}, @@ -20,11 +21,12 @@ use ruma::{ serde::Raw, uint, RoomId, UInt, UserId, }; +use std::sync::Arc; use std::{fmt::Debug, mem}; use tracing::{error, info, warn}; pub struct Service { - db: Box, + db: Arc, } impl Service { @@ -47,8 +49,9 @@ impl Service { self.db.get_pusher_senderkeys(sender) } - #[tracing::instrument(skip(destination, request))] + #[tracing::instrument(skip(self, destination, request))] pub async fn send_request( + &self, destination: &str, request: T, ) -> Result @@ -124,7 +127,7 @@ impl Service { } } - #[tracing::instrument(skip(user, unread, pusher, ruleset, pdu))] + #[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))] pub async fn send_push_notice( &self, user: &UserId, @@ -181,7 +184,7 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(user, ruleset, pdu))] + #[tracing::instrument(skip(self, user, ruleset, pdu))] pub fn get_actions<'a>( &self, user: &UserId, @@ -204,7 +207,7 @@ impl Service { Ok(ruleset.get_actions(pdu, &ctx)) } - #[tracing::instrument(skip(unread, pusher, tweaks, event))] + #[tracing::instrument(skip(self, unread, pusher, tweaks, event))] async fn send_notice( &self, unread: UInt, diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index 81022096..26bffae2 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -1,7 +1,7 @@ use ruma::{RoomId, RoomAliasId}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { /// Creates or updates the alias to the given room id. fn set_alias( &self, diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index e4e8550b..13fac2dc 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -1,7 +1,7 @@ -use std::collections::HashSet; +use std::{collections::HashSet, sync::Arc}; use crate::Result; -pub trait Data { - fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result>>; - fn cache_eventid_authchain(&self, shorteventid: u64, auth_chain: &HashSet) -> Result<()>; +pub trait Data: Send + Sync { + fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result>>>; + fn cache_auth_chain(&self, shorteventid: Vec, auth_chain: Arc>) -> Result<()>; } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 26a3f3f0..5fe0e3e8 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -15,41 +15,11 @@ impl Service { &'a self, key: &[u64], ) -> Result>>> { - // Check RAM cache - if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key.to_be_bytes()) { - return Ok(Some(Arc::clone(result))); - } - - // We only save auth chains for single events in the db - if key.len() == 1 { - // Check DB cache - if let Some(chain) = self.db.get_cached_eventid_authchain(key[0]) - { - let chain = Arc::new(chain); - - // Cache in RAM - self.auth_chain_cache - .lock() - .unwrap() - .insert(vec![key[0]], Arc::clone(&chain)); - - return Ok(Some(chain)); - } - } - - Ok(None) + self.db.get_cached_eventid_authchain(key) } #[tracing::instrument(skip(self))] pub fn cache_auth_chain(&self, key: Vec, auth_chain: Arc>) -> Result<()> { - // Only persist single events in db - if key.len() == 1 { - self.db.cache_auth_chain(key[0], auth_chain)?; - } - - // Cache in RAM - self.auth_chain_cache.lock().unwrap().insert(key, auth_chain); - - Ok(()) + self.db.cache_auth_chain(key, auth_chain) } } diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs index 13767217..b4e020d7 100644 --- a/src/service/rooms/directory/data.rs +++ b/src/service/rooms/directory/data.rs @@ -1,7 +1,7 @@ use ruma::RoomId; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { /// Adds the room to the public room directory fn set_public(&self, room_id: &RoomId) -> Result<()>; diff --git a/src/service/rooms/edus/presence/data.rs b/src/service/rooms/edus/presence/data.rs index ca0e2410..f7592555 100644 --- a/src/service/rooms/edus/presence/data.rs +++ b/src/service/rooms/edus/presence/data.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use ruma::{UserId, RoomId, events::presence::PresenceEvent}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { /// Adds a presence event which will be saved until a new event replaces it. /// /// Note: This method takes a RoomId because presence updates are always bound to rooms to diff --git a/src/service/rooms/edus/read_receipt/data.rs b/src/service/rooms/edus/read_receipt/data.rs index e8ed9656..5ebd89d6 100644 --- a/src/service/rooms/edus/read_receipt/data.rs +++ b/src/service/rooms/edus/read_receipt/data.rs @@ -1,7 +1,7 @@ use ruma::{RoomId, events::receipt::ReceiptEvent, UserId, serde::Raw}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { /// Replaces the previous read receipt. fn readreceipt_update( &self, diff --git a/src/service/rooms/edus/typing/data.rs b/src/service/rooms/edus/typing/data.rs index ec0be466..426d4e06 100644 --- a/src/service/rooms/edus/typing/data.rs +++ b/src/service/rooms/edus/typing/data.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use crate::Result; use ruma::{UserId, RoomId}; -pub trait Data { +pub trait Data: Send + Sync { /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is /// called. fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()>; diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index e2291126..ac3cca6a 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -117,7 +117,7 @@ impl Service { room_id, pub_key_map, incoming_pdu.prev_events.clone(), - ).await; + ).await?; let mut errors = 0; for prev_id in dbg!(sorted_prev_events) { @@ -240,7 +240,7 @@ impl Service { r } - #[tracing::instrument(skip(create_event, value, pub_key_map))] + #[tracing::instrument(skip(self, create_event, value, pub_key_map))] fn handle_outlier_pdu<'a>( &self, origin: &'a ServerName, @@ -272,7 +272,7 @@ impl Service { RoomVersion::new(room_version_id).expect("room version is supported"); let mut val = match ruma::signatures::verify_event( - &*pub_key_map.read().map_err(|_| "RwLock is poisoned.")?, + &*pub_key_map.read().expect("RwLock is poisoned."), &value, room_version_id, ) { @@ -301,7 +301,7 @@ impl Service { let incoming_pdu = serde_json::from_value::( serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), ) - .map_err(|_| "Event is not a valid PDU.".to_owned())?; + .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; // 4. fetch any missing auth events doing all checks listed here starting at 1. These are not timeline events // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" @@ -329,7 +329,7 @@ impl Service { // Build map of auth events let mut auth_events = HashMap::new(); for id in &incoming_pdu.auth_events { - let auth_event = match services().rooms.get_pdu(id)? { + let auth_event = match services().rooms.timeline.get_pdu(id)? { Some(e) => e, None => { warn!("Could not find auth event {}", id); @@ -373,7 +373,8 @@ impl Service { &incoming_pdu, None::, // TODO: third party invite |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), - )? { + ).map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed"))? + { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Auth check failed", @@ -385,6 +386,7 @@ impl Service { // 7. Persist the event as an outlier. services() .rooms + .outlier .add_pdu_outlier(&incoming_pdu.event_id, &val)?; info!("Added pdu as outlier."); @@ -393,7 +395,7 @@ impl Service { }) } - #[tracing::instrument(skip(incoming_pdu, val, create_event, pub_key_map))] + #[tracing::instrument(skip(self, incoming_pdu, val, create_event, pub_key_map))] pub async fn upgrade_outlier_to_timeline_pdu( &self, incoming_pdu: Arc, @@ -412,7 +414,7 @@ impl Service { .rooms .pdu_metadata.is_event_soft_failed(&incoming_pdu.event_id)? { - return Err("Event has been soft failed".into()); + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); } info!("Upgrading {} to timeline pdu", incoming_pdu.event_id); @@ -1130,7 +1132,8 @@ impl Service { room_id: &RoomId, pub_key_map: &RwLock>>, initial_set: Vec>, - ) -> Vec<(Arc, HashMap, (Arc, BTreeMap)>)> { + ) -> Result<(Vec>, HashMap, +(Arc, BTreeMap)>)> { let mut graph: HashMap, _> = HashMap::new(); let mut eventid_info = HashMap::new(); let mut todo_outlier_stack: Vec> = initial_set; @@ -1164,6 +1167,7 @@ impl Service { if let Some(json) = json_opt.or_else(|| { services() .rooms + .outlier .get_outlier_pdu_json(&prev_event_id) .ok() .flatten() @@ -1209,9 +1213,9 @@ impl Service { .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), ), )) - })?; + }).map_err(|_| Error::bad_database("Error sorting prev events"))?; - (sorted, eventid_info) + Ok((sorted, eventid_info)) } #[tracing::instrument(skip_all)] diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs index f1019c13..524071c3 100644 --- a/src/service/rooms/lazy_loading/data.rs +++ b/src/service/rooms/lazy_loading/data.rs @@ -1,7 +1,7 @@ use ruma::{RoomId, DeviceId, UserId}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { fn lazy_load_was_sent_before( &self, user_id: &UserId, diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index 9b1ce079..9444db41 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -1,6 +1,6 @@ use ruma::RoomId; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { fn exists(&self, room_id: &RoomId) -> Result; } diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs index 17d0f7b4..edc7c4fd 100644 --- a/src/service/rooms/outlier/data.rs +++ b/src/service/rooms/outlier/data.rs @@ -2,7 +2,7 @@ use ruma::{signatures::CanonicalJsonObject, EventId}; use crate::{PduEvent, Result}; -pub trait Data { +pub trait Data: Send + Sync { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result>; fn get_outlier_pdu(&self, event_id: &EventId) -> Result>; fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>; diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index fb839023..9bc49cfb 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use ruma::{EventId, RoomId}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()>; fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result; fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>; diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index b62904c1..0c14ffe6 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,7 +1,7 @@ use ruma::RoomId; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()>; fn search_pdus<'a>( diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index 3b1c3117..bc2b28f0 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -1,2 +1,2 @@ -pub trait Data { +pub trait Data: Send + Sync { } diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 7008d86f..20c177a2 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,9 +1,10 @@ use std::sync::Arc; -use std::{sync::MutexGuard, collections::HashSet}; +use std::collections::HashSet; use crate::Result; use ruma::{EventId, RoomId}; +use tokio::sync::MutexGuard; -pub trait Data { +pub trait Data: Send + Sync { /// Returns the last state hash key added to the db for the given room. fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result>; @@ -21,7 +22,7 @@ pub trait Data { /// Replace the forward extremities of the room. fn set_forward_extremities<'a>(&self, room_id: &RoomId, - event_ids: &dyn Iterator, + event_ids: &mut dyn Iterator, _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()>; } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 979060d9..53859785 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -16,7 +16,7 @@ pub struct Service { impl Service { /// Set the room to the given statehash and update caches. - pub fn force_state( + pub async fn force_state( &self, room_id: &RoomId, shortstatehash: u64, @@ -28,7 +28,7 @@ impl Service { .roomid_mutex_state .write() .unwrap() - .entry(body.room_id.to_owned()) + .entry(room_id.to_owned()) .or_default(), ); let state_lock = mutex_state.lock().await; @@ -74,10 +74,10 @@ impl Service { Err(_) => continue, }; - services().room.state_cache.update_membership(room_id, &user_id, membership, &pdu.sender, None, false)?; + services().rooms.state_cache.update_membership(room_id, &user_id, membership, &pdu.sender, None, false)?; } - services().room.state_cache.update_joined_count(room_id)?; + services().rooms.state_cache.update_joined_count(room_id)?; self.db.set_room_state(room_id, shortstatehash, &state_lock); diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 48031e49..14f96bc8 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -6,7 +6,7 @@ use ruma::{EventId, events::StateEventType, RoomId}; use crate::{Result, PduEvent}; #[async_trait] -pub trait Data { +pub trait Data: Send + Sync { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. async fn state_full_ids(&self, shortstatehash: u64) -> Result>>; diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index b45b2ea0..b9db7217 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,7 +1,7 @@ use ruma::{UserId, RoomId, serde::Raw, events::AnyStrippedStateEvent}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_invited(&self, user_id: &UserId, room_id: &RoomId, last_state: Option>>) -> Result<()>; diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index cd872422..ce164c6d 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -9,7 +9,7 @@ pub struct StateDiff { pub removed: HashSet, } -pub trait Data { +pub trait Data: Send + Sync { fn get_statediff(&self, shortstatehash: u64) -> Result; fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()>; } diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 85bedc69..d073e865 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -4,7 +4,7 @@ use ruma::{signatures::CanonicalJsonObject, EventId, UserId, RoomId}; use crate::{Result, PduEvent}; -pub trait Data { +pub trait Data: Send + Sync { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; /// Returns the `count` of this pdu's id. diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index a5657bc1..6b7ebc72 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,7 +1,7 @@ use ruma::{UserId, RoomId}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs index 6e71dd46..c5ff05c0 100644 --- a/src/service/transaction_ids/data.rs +++ b/src/service/transaction_ids/data.rs @@ -1,7 +1,7 @@ use ruma::{DeviceId, UserId, TransactionId}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { fn add_txnid( &self, user_id: &UserId, diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs index d7fa79d2..091f0641 100644 --- a/src/service/uiaa/data.rs +++ b/src/service/uiaa/data.rs @@ -1,7 +1,7 @@ use ruma::{api::client::uiaa::UiaaInfo, DeviceId, UserId, signatures::CanonicalJsonValue}; use crate::Result; -pub trait Data { +pub trait Data: Send + Sync { fn set_uiaa_request( &self, user_id: &UserId, diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 8adc9366..b13ae1f2 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -223,18 +223,18 @@ impl Service { self.db.get_device_keys(user_id, device_id) } - pub fn get_master_key bool>( + pub fn get_master_key( &self, user_id: &UserId, - allowed_signatures: F, + allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { self.db.get_master_key(user_id, allowed_signatures) } - pub fn get_self_signing_key bool>( + pub fn get_self_signing_key( &self, user_id: &UserId, - allowed_signatures: F, + allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { self.db.get_self_signing_key(user_id, allowed_signatures) }