From f56424bc8d8d582c52be91116ceb29d69791c563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 7 Aug 2022 19:42:22 +0200 Subject: [PATCH] Refactor appservices, pusher, timeline, transactionids, users --- .../key_value}/appservice.rs | 24 +- src/database/key_value/pusher.rs | 302 +----- src/database/key_value/rooms/timeline.rs | 663 +----------- .../key_value}/transaction_ids.rs | 13 +- src/database/key_value/users.rs | 148 +-- src/service/appservice/data.rs | 17 + src/service/appservice/mod.rs | 36 + src/service/globals.rs | 14 +- src/service/pusher.rs | 348 ------- src/service/pusher/data.rs | 346 +------ src/service/pusher/mod.rs | 575 +++++------ src/service/rooms/short/mod.rs | 11 +- src/service/rooms/timeline/data.rs | 901 +--------------- src/service/rooms/timeline/mod.rs | 232 +---- src/service/transaction_ids/data.rs | 16 + src/service/transaction_ids/mod.rs | 44 + src/service/users/data.rs | 961 +----------------- src/service/users/mod.rs | 845 +-------------- 18 files changed, 546 insertions(+), 4950 deletions(-) rename src/{service => database/key_value}/appservice.rs (77%) rename src/{service => database/key_value}/transaction_ids.rs (77%) create mode 100644 src/service/appservice/data.rs create mode 100644 src/service/appservice/mod.rs delete mode 100644 src/service/pusher.rs create mode 100644 src/service/transaction_ids/data.rs create mode 100644 src/service/transaction_ids/mod.rs diff --git a/src/service/appservice.rs b/src/database/key_value/appservice.rs similarity index 77% rename from src/service/appservice.rs rename to src/database/key_value/appservice.rs index edd5009b..66a2a5c8 100644 --- a/src/service/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -1,19 +1,5 @@ -use crate::{utils, Error, Result}; -use std::{ - collections::HashMap, - sync::{Arc, RwLock}, -}; - -use super::abstraction::Tree; - -pub struct Appservice { - pub(super) cached_registrations: Arc>>, - pub(super) id_appserviceregistrations: Arc, -} - -impl Appservice { +impl service::appservice::Data for KeyValueDatabase { /// Registers an appservice and returns the ID to the caller - /// pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { // TODO: Rumaify let id = yaml.get("id").unwrap().as_str().unwrap(); @@ -34,7 +20,7 @@ impl Appservice { /// # Arguments /// /// * `service_name` - the name you send to register the service previously - pub fn unregister_appservice(&self, service_name: &str) -> Result<()> { + fn unregister_appservice(&self, service_name: &str) -> Result<()> { self.id_appserviceregistrations .remove(service_name.as_bytes())?; self.cached_registrations @@ -44,7 +30,7 @@ impl Appservice { Ok(()) } - pub fn get_registration(&self, id: &str) -> Result> { + fn get_registration(&self, id: &str) -> Result> { self.cached_registrations .read() .unwrap() @@ -66,14 +52,14 @@ impl Appservice { ) } - pub fn iter_ids(&self) -> Result> + '_> { + fn iter_ids(&self) -> Result> + '_> { Ok(self.id_appserviceregistrations.iter().map(|(id, _)| { utils::string_from_bytes(&id) .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) })) } - pub fn all(&self) -> Result> { + fn all(&self) -> Result> { self.iter_ids()? .filter_map(|id| id.ok()) .map(move |id| { diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index 6b906c24..94374ab2 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -1,36 +1,5 @@ -use crate::{Database, Error, PduEvent, Result}; -use bytes::BytesMut; -use ruma::{ - api::{ - client::push::{get_pushers, set_pusher, PusherKind}, - push_gateway::send_event_notification::{ - self, - v1::{Device, Notification, NotificationCounts, NotificationPriority}, - }, - IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, - }, - events::{ - room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent}, - AnySyncRoomEvent, RoomEventType, StateEventType, - }, - push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, - serde::Raw, - uint, RoomId, UInt, UserId, -}; -use tracing::{error, info, warn}; - -use std::{fmt::Debug, mem, sync::Arc}; - -use super::abstraction::Tree; - -pub struct PushData { - /// UserId + pushkey -> Pusher - pub(super) senderkey_pusher: Arc, -} - -impl PushData { - #[tracing::instrument(skip(self, sender, pusher))] - pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { +impl service::pusher::Data for KeyValueDatabase { + fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { let mut key = sender.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(pusher.pushkey.as_bytes()); @@ -52,8 +21,7 @@ impl PushData { Ok(()) } - #[tracing::instrument(skip(self, senderkey))] - pub fn get_pusher(&self, senderkey: &[u8]) -> Result> { + fn get_pusher(&self, senderkey: &[u8]) -> Result> { self.senderkey_pusher .get(senderkey)? .map(|push| { @@ -63,8 +31,7 @@ impl PushData { .transpose() } - #[tracing::instrument(skip(self, sender))] - pub fn get_pushers(&self, sender: &UserId) -> Result> { + fn get_pushers(&self, sender: &UserId) -> Result> { let mut prefix = sender.as_bytes().to_vec(); prefix.push(0xff); @@ -77,8 +44,7 @@ impl PushData { .collect() } - #[tracing::instrument(skip(self, sender))] - pub fn get_pusher_senderkeys<'a>( + fn get_pusher_senderkeys<'a>( &'a self, sender: &UserId, ) -> impl Iterator> + 'a { @@ -88,261 +54,3 @@ impl PushData { self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) } } - -#[tracing::instrument(skip(globals, destination, request))] -pub async fn send_request( - globals: &crate::database::globals::Globals, - destination: &str, - request: T, -) -> Result -where - T: Debug, -{ - let destination = destination.replace("/_matrix/push/v1/notify", ""); - - let http_request = request - .try_into_http_request::( - &destination, - SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_0], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", destination, e); - Error::BadServerResponse("Invalid destination") - })? - .map(|body| body.freeze()); - - let reqwest_request = reqwest::Request::try_from(http_request) - .expect("all http requests are valid reqwest requests"); - - // TODO: we could keep this very short and let expo backoff do it's thing... - //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); - - let url = reqwest_request.url().clone(); - let response = globals.default_client().execute(reqwest_request).await; - - match response { - Ok(mut response) => { - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if status != 200 { - info!( - "Push gateway returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - crate::utils::string_from_bytes(&body) - ); - } - - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - response.map_err(|_| { - info!( - "Push gateway returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Push gateway returned bad response.") - }) - } - Err(e) => Err(e.into()), - } -} - -#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] -pub async fn send_push_notice( - user: &UserId, - unread: UInt, - pusher: &get_pushers::v3::Pusher, - ruleset: Ruleset, - pdu: &PduEvent, - db: &Database, -) -> Result<()> { - let mut notify = None; - let mut tweaks = Vec::new(); - - let power_levels: RoomPowerLevelsEventContent = db - .rooms - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); - - for action in get_actions( - user, - &ruleset, - &power_levels, - &pdu.to_sync_room_event(), - &pdu.room_id, - db, - )? { - let n = match action { - Action::DontNotify => false, - // TODO: Implement proper support for coalesce - Action::Notify | Action::Coalesce => true, - Action::SetTweak(tweak) => { - tweaks.push(tweak.clone()); - continue; - } - }; - - if notify.is_some() { - return Err(Error::bad_database( - r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, - )); - } - - notify = Some(n); - } - - if notify == Some(true) { - send_notice(unread, pusher, tweaks, pdu, db).await?; - } - // Else the event triggered no actions - - Ok(()) -} - -#[tracing::instrument(skip(user, ruleset, pdu, db))] -pub fn get_actions<'a>( - user: &UserId, - ruleset: &'a Ruleset, - power_levels: &RoomPowerLevelsEventContent, - pdu: &Raw, - room_id: &RoomId, - db: &Database, -) -> Result<&'a [Action]> { - let ctx = PushConditionRoomCtx { - room_id: room_id.to_owned(), - member_count: 10_u32.into(), // TODO: get member count efficiently - user_display_name: db - .users - .displayname(user)? - .unwrap_or_else(|| user.localpart().to_owned()), - users_power_levels: power_levels.users.clone(), - default_power_level: power_levels.users_default, - notification_power_levels: power_levels.notifications.clone(), - }; - - Ok(ruleset.get_actions(pdu, &ctx)) -} - -#[tracing::instrument(skip(unread, pusher, tweaks, event, db))] -async fn send_notice( - unread: UInt, - pusher: &get_pushers::v3::Pusher, - tweaks: Vec, - event: &PduEvent, - db: &Database, -) -> Result<()> { - // TODO: email - if pusher.kind == PusherKind::Email { - return Ok(()); - } - - // TODO: - // Two problems with this - // 1. if "event_id_only" is the only format kind it seems we should never add more info - // 2. can pusher/devices have conflicting formats - let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly); - let url = if let Some(url) = &pusher.data.url { - url - } else { - error!("Http Pusher must have URL specified."); - return Ok(()); - }; - - let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone()); - let mut data_minus_url = pusher.data.clone(); - // The url must be stripped off according to spec - data_minus_url.url = None; - device.data = data_minus_url; - - // Tweaks are only added if the format is NOT event_id_only - if !event_id_only { - device.tweaks = tweaks.clone(); - } - - let d = &[device]; - let mut notifi = Notification::new(d); - - notifi.prio = NotificationPriority::Low; - notifi.event_id = Some(&event.event_id); - notifi.room_id = Some(&event.room_id); - // TODO: missed calls - notifi.counts = NotificationCounts::new(unread, uint!(0)); - - if event.kind == RoomEventType::RoomEncrypted - || tweaks - .iter() - .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) - { - notifi.prio = NotificationPriority::High - } - - if event_id_only { - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; - } else { - notifi.sender = Some(&event.sender); - notifi.event_type = Some(&event.kind); - let content = serde_json::value::to_raw_value(&event.content).ok(); - notifi.content = content.as_deref(); - - if event.kind == RoomEventType::RoomMember { - notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); - } - - let user_name = db.users.displayname(&event.sender)?; - notifi.sender_display_name = user_name.as_deref(); - - let room_name = if let Some(room_name_pdu) = - db.rooms - .room_state_get(&event.room_id, &StateEventType::RoomName, "")? - { - serde_json::from_str::(room_name_pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid room name event in database."))? - .name - } else { - None - }; - - notifi.room_name = room_name.as_deref(); - - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; - } - - // TODO: email - - Ok(()) -} diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 5b423d2d..58884ec3 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -1,28 +1,5 @@ - - /// Checks if a room exists. - #[tracing::instrument(skip(self))] - pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - // Look for PDUs in that room. - self.pduid_pdu - .iter_from(&prefix, false) - .filter(|(k, _)| k.starts_with(&prefix)) - .map(|(_, pdu)| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid first PDU in db.")) - .map(Arc::new) - }) - .next() - .transpose() - } - - #[tracing::instrument(skip(self))] - pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { +impl service::room::timeline::Data for KeyValueDatabase { + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self .lasttimelinecount_cache .lock() @@ -51,31 +28,8 @@ } } - // TODO Is this the same as the function above? - #[tracing::instrument(skip(self))] - pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.pduid_pdu - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|b| self.pdu_count(&b.0)) - .transpose() - .map(|op| op.unwrap_or_default()) - } - - - /// Returns the `count` of this pdu's id. - pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { + fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? .map(|pdu_id| self.pdu_count(&pdu_id)) @@ -207,7 +161,6 @@ } /// Removes a pdu and creates a new one with the same id. - #[tracing::instrument(skip(self))] fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { if self.pduid_pdu.get(pdu_id)?.is_some() { self.pduid_pdu.insert( @@ -223,598 +176,8 @@ } } - /// Creates a new persisted data unit and adds it to a room. - /// - /// By this point the incoming event should be fully authenticated, no auth happens - /// in `append_pdu`. - /// - /// Returns pdu id - #[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))] - pub fn append_pdu<'a>( - &self, - pdu: &PduEvent, - mut pdu_json: CanonicalJsonObject, - leaves: impl IntoIterator + Debug, - db: &Database, - ) -> Result> { - let shortroomid = self.get_shortroomid(&pdu.room_id)?.expect("room exists"); - - // Make unsigned fields correct. This is not properly documented in the spec, but state - // events need to have previous content in the unsigned field, so clients can easily - // interpret things like membership changes - if let Some(state_key) = &pdu.state_key { - if let CanonicalJsonValue::Object(unsigned) = pdu_json - .entry("unsigned".to_owned()) - .or_insert_with(|| CanonicalJsonValue::Object(Default::default())) - { - if let Some(shortstatehash) = self.pdu_shortstatehash(&pdu.event_id).unwrap() { - if let Some(prev_state) = self - .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) - .unwrap() - { - unsigned.insert( - "prev_content".to_owned(), - CanonicalJsonValue::Object( - utils::to_canonical_object(prev_state.content.clone()) - .expect("event is valid, we just created it"), - ), - ); - } - } - } else { - error!("Invalid unsigned type in pdu."); - } - } - - // We must keep track of all events that have been referenced. - self.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - self.replace_pdu_leaves(&pdu.room_id, leaves)?; - - let mutex_insert = Arc::clone( - db.globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(pdu.room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().unwrap(); - - let count1 = db.globals.next_count()?; - // Mark as read first so the sending client doesn't get a notification even if appending - // fails - self.edus - .private_read_set(&pdu.room_id, &pdu.sender, count1, &db.globals)?; - self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; - - let count2 = db.globals.next_count()?; - let mut pdu_id = shortroomid.to_be_bytes().to_vec(); - pdu_id.extend_from_slice(&count2.to_be_bytes()); - - // There's a brief moment of time here where the count is updated but the pdu does not - // exist. This could theoretically lead to dropped pdus, but it's extremely rare - // - // Update: We fixed this using insert_lock - - self.pduid_pdu.insert( - &pdu_id, - &serde_json::to_vec(&pdu_json).expect("CanonicalJsonObject is always a valid"), - )?; - self.lasttimelinecount_cache - .lock() - .unwrap() - .insert(pdu.room_id.clone(), count2); - - self.eventid_pduid - .insert(pdu.event_id.as_bytes(), &pdu_id)?; - self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; - - drop(insert_lock); - - // See if the event matches any known pushers - let power_levels: RoomPowerLevelsEventContent = db - .rooms - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); - - let sync_pdu = pdu.to_sync_room_event(); - - let mut notifies = Vec::new(); - let mut highlights = Vec::new(); - - for user in self.get_our_real_users(&pdu.room_id, db)?.iter() { - // Don't notify the user of their own events - if user == &pdu.sender { - continue; - } - - let rules_for_user = db - .account_data - .get( - None, - user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? - .map(|ev: PushRulesEvent| ev.content.global) - .unwrap_or_else(|| Ruleset::server_default(user)); - - let mut highlight = false; - let mut notify = false; - - for action in pusher::get_actions( - user, - &rules_for_user, - &power_levels, - &sync_pdu, - &pdu.room_id, - db, - )? { - match action { - Action::DontNotify => notify = false, - // TODO: Implement proper support for coalesce - Action::Notify | Action::Coalesce => notify = true, - Action::SetTweak(Tweak::Highlight(true)) => { - highlight = true; - } - _ => {} - }; - } - - let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(pdu.room_id.as_bytes()); - - if notify { - notifies.push(userroom_id.clone()); - } - - if highlight { - highlights.push(userroom_id); - } - - for senderkey in db.pusher.get_pusher_senderkeys(user) { - db.sending.send_push_pdu(&*pdu_id, senderkey)?; - } - } - - self.userroomid_notificationcount - .increment_batch(&mut notifies.into_iter())?; - self.userroomid_highlightcount - .increment_batch(&mut highlights.into_iter())?; - - match pdu.kind { - RoomEventType::RoomRedaction => { - if let Some(redact_id) = &pdu.redacts { - self.redact_pdu(redact_id, pdu)?; - } - } - RoomEventType::RoomMember => { - if let Some(state_key) = &pdu.state_key { - #[derive(Deserialize)] - struct ExtractMembership { - membership: MembershipState, - } - - // if the state_key fails - let target_user_id = UserId::parse(state_key.clone()) - .expect("This state_key was previously validated"); - - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - - let invite_state = match content.membership { - MembershipState::Invite => { - let state = self.calculate_invite_state(pdu)?; - Some(state) - } - _ => None, - }; - - // Update our membership info, we do this here incase a user is invited - // and immediately leaves we need the DB to record the invite event for auth - self.update_membership( - &pdu.room_id, - &target_user_id, - content.membership, - &pdu.sender, - invite_state, - db, - true, - )?; - } - } - RoomEventType::RoomMessage => { - #[derive(Deserialize)] - struct ExtractBody<'a> { - #[serde(borrow)] - body: Option>, - } - - let content = serde_json::from_str::>(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - - if let Some(body) = content.body { - DB.rooms.search.index_pdu(room_id, pdu_id, body)?; - - let admin_room = self.id_from_alias( - <&RoomAliasId>::try_from( - format!("#admins:{}", db.globals.server_name()).as_str(), - ) - .expect("#admins:server_name is a valid room alias"), - )?; - let server_user = format!("@conduit:{}", db.globals.server_name()); - - let to_conduit = body.starts_with(&format!("{}: ", server_user)); - - // This will evaluate to false if the emergency password is set up so that - // the administrator can execute commands as conduit - let from_conduit = - pdu.sender == server_user && db.globals.emergency_password().is_none(); - - if to_conduit && !from_conduit && admin_room.as_ref() == Some(&pdu.room_id) { - db.admin.process_message(body.to_string()); - } - } - } - _ => {} - } - - for appservice in db.appservice.all()? { - if self.appservice_in_room(room_id, &appservice, db)? { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - continue; - } - - // If the RoomMember event has a non-empty state_key, it is targeted at someone. - // If it is our appservice user, we send this PDU to it. - if pdu.kind == RoomEventType::RoomMember { - if let Some(state_key_uid) = &pdu - .state_key - .as_ref() - .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) - { - if let Some(appservice_uid) = appservice - .1 - .get("sender_localpart") - .and_then(|string| string.as_str()) - .and_then(|string| { - UserId::parse_with_server_name(string, db.globals.server_name()).ok() - }) - { - if state_key_uid == &appservice_uid { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - continue; - } - } - } - } - - if let Some(namespaces) = appservice.1.get("namespaces") { - let users = namespaces - .get("users") - .and_then(|users| users.as_sequence()) - .map_or_else(Vec::new, |users| { - users - .iter() - .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) - .collect::>() - }); - let aliases = namespaces - .get("aliases") - .and_then(|aliases| aliases.as_sequence()) - .map_or_else(Vec::new, |aliases| { - aliases - .iter() - .filter_map(|aliases| Regex::new(aliases.get("regex")?.as_str()?).ok()) - .collect::>() - }); - let rooms = namespaces - .get("rooms") - .and_then(|rooms| rooms.as_sequence()); - - let matching_users = |users: &Regex| { - users.is_match(pdu.sender.as_str()) - || pdu.kind == RoomEventType::RoomMember - && pdu - .state_key - .as_ref() - .map_or(false, |state_key| users.is_match(state_key)) - }; - let matching_aliases = |aliases: &Regex| { - self.room_aliases(room_id) - .filter_map(|r| r.ok()) - .any(|room_alias| aliases.is_match(room_alias.as_str())) - }; - - if aliases.iter().any(matching_aliases) - || rooms.map_or(false, |rooms| rooms.contains(&room_id.as_str().into())) - || users.iter().any(matching_users) - { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - } - } - } - - - Ok(pdu_id) - } - - pub fn create_hash_and_sign_event( - &self, - pdu_builder: PduBuilder, - sender: &UserId, - room_id: &RoomId, - db: &Database, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> (PduEvent, CanonicalJsonObj) { - let PduBuilder { - event_type, - content, - unsigned, - state_key, - redacts, - } = pdu_builder; - - let prev_events: Vec<_> = db - .rooms - .get_pdu_leaves(room_id)? - .into_iter() - .take(20) - .collect(); - - let create_event = db - .rooms - .room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: Option = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()?; - - // If there was no create event yet, assume we are creating a room with the default - // version right now - let room_version_id = create_event_content - .map_or(db.globals.default_room_version(), |create_event| { - create_event.room_version - }); - let room_version = - RoomVersion::new(&room_version_id).expect("room version is supported"); - - let auth_events = - self.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; - - // Our depth is the maximum depth of prev_events + 1 - let depth = prev_events - .iter() - .filter_map(|event_id| Some(db.rooms.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) - + uint!(1); - - let mut unsigned = unsigned.unwrap_or_default(); - - if let Some(state_key) = &state_key { - if let Some(prev_pdu) = - self.room_state_get(room_id, &event_type.to_string().into(), state_key)? - { - unsigned.insert( - "prev_content".to_owned(), - serde_json::from_str(prev_pdu.content.get()).expect("string is valid json"), - ); - unsigned.insert( - "prev_sender".to_owned(), - serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), - ); - } - } - - let pdu = PduEvent { - event_id: ruma::event_id!("$thiswillbefilledinlater").into(), - room_id: room_id.to_owned(), - sender: sender_user.to_owned(), - origin_server_ts: utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"), - kind: event_type, - content, - state_key, - prev_events, - depth, - auth_events: auth_events - .iter() - .map(|(_, pdu)| pdu.event_id.clone()) - .collect(), - redacts, - unsigned: if unsigned.is_empty() { - None - } else { - Some(to_raw_value(&unsigned).expect("to_raw_value always works")) - }, - hashes: EventHash { - sha256: "aaa".to_owned(), - }, - signatures: None, - }; - - let auth_check = state_res::auth_check( - &room_version, - &pdu, - None::, // TODO: third_party_invite - |k, s| auth_events.get(&(k.clone(), s.to_owned())), - ) - .map_err(|e| { - error!("{:?}", e); - Error::bad_database("Auth check failed.") - })?; - - if !auth_check { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Event is not authorized.", - )); - } - - // Hash and sign - let mut pdu_json = - utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); - - pdu_json.remove("event_id"); - - // Add origin because synapse likes that (and it's required in the spec) - pdu_json.insert( - "origin".to_owned(), - to_canonical_value(db.globals.server_name()) - .expect("server name is a valid CanonicalJsonValue"), - ); - - match ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut pdu_json, - &room_version_id, - ) { - Ok(_) => {} - Err(e) => { - return match e { - ruma::signatures::Error::PduSize => Err(Error::BadRequest( - ErrorKind::TooLarge, - "Message is too long", - )), - _ => Err(Error::BadRequest( - ErrorKind::Unknown, - "Signing event failed", - )), - } - } - } - - // Generate event id - pdu.event_id = EventId::parse_arc(format!( - "${}", - ruma::signatures::reference_hash(&pdu_json, &room_version_id) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); - - pdu_json.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(pdu.event_id.as_str().to_owned()), - ); - - // Generate short event id - let _shorteventid = self.get_or_create_shorteventid(&pdu.event_id, &db.globals)?; - } - - /// Creates a new persisted data unit and adds it to a room. This function takes a - /// roomid_mutex_state, meaning that only this function is able to mutate the room state. - #[tracing::instrument(skip(self, db, _mutex_lock))] - pub fn build_and_append_pdu( - &self, - pdu_builder: PduBuilder, - sender: &UserId, - room_id: &RoomId, - db: &Database, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result> { - - let (pdu, pdu_json) = create_hash_and_sign_event()?; - - - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. - let statehashid = self.append_to_state(&pdu, &db.globals)?; - - let pdu_id = self.append_pdu( - &pdu, - pdu_json, - // Since this PDU references all pdu_leaves we can update the leaves - // of the room - iter::once(&*pdu.event_id), - db, - )?; - - // We set the room state after inserting the pdu, so that we never have a moment in time - // where events in the current room state do not exist - self.set_room_state(room_id, statehashid)?; - - let mut servers: HashSet> = - self.room_servers(room_id).filter_map(|r| r.ok()).collect(); - - // In case we are kicking or banning a user, we need to inform their server of the change - if pdu.kind == RoomEventType::RoomMember { - if let Some(state_key_uid) = &pdu - .state_key - .as_ref() - .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) - { - servers.insert(Box::from(state_key_uid.server_name())); - } - } - - // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above - servers.remove(db.globals.server_name()); - - db.sending.send_pdu(servers.into_iter(), &pdu_id)?; - - Ok(pdu.event_id) - } - - /// Append the incoming event setting the state snapshot to the state from the - /// server that sent the event. - #[tracing::instrument(skip_all)] - fn append_incoming_pdu<'a>( - db: &Database, - pdu: &PduEvent, - pdu_json: CanonicalJsonObject, - new_room_leaves: impl IntoIterator + Clone + Debug, - state_ids_compressed: HashSet, - soft_fail: bool, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result>> { - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. - db.rooms.set_event_state( - &pdu.event_id, - &pdu.room_id, - state_ids_compressed, - &db.globals, - )?; - - if soft_fail { - db.rooms - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - db.rooms.replace_pdu_leaves(&pdu.room_id, new_room_leaves)?; - return Ok(None); - } - - let pdu_id = db.rooms.append_pdu(pdu, pdu_json, new_room_leaves, db)?; - - Ok(Some(pdu_id)) - } - - /// Returns an iterator over all PDUs in a room. - #[tracing::instrument(skip(self))] - pub fn all_pdus<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result, PduEvent)>> + 'a> { - self.pdus_since(user_id, room_id, 0) - } - /// Returns an iterator over all events in a room that happened after the event with id `since` /// in chronological order. - #[tracing::instrument(skip(self))] pub fn pdus_since<'a>( &'a self, user_id: &UserId, @@ -849,7 +212,6 @@ /// Returns an iterator over all events and their tokens in a room that happened before the /// event with id `until` in reverse-chronological order. - #[tracing::instrument(skip(self))] pub fn pdus_until<'a>( &'a self, user_id: &UserId, @@ -884,9 +246,6 @@ })) } - /// Returns an iterator over all events and their token in a room that happened after the event - /// with id `from` in chronological order. - #[tracing::instrument(skip(self))] pub fn pdus_after<'a>( &'a self, user_id: &UserId, @@ -920,18 +279,4 @@ Ok((pdu_id, pdu)) })) } - - /// Replace a PDU with the redacted form. - #[tracing::instrument(skip(self, reason))] - pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { - if let Some(pdu_id) = self.get_pdu_id(event_id)? { - let mut pdu = self - .get_pdu_from_id(&pdu_id)? - .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; - pdu.redact(reason)?; - self.replace_pdu(&pdu_id, &pdu)?; - } - // If event does not exist, just noop - Ok(()) - } - +} diff --git a/src/service/transaction_ids.rs b/src/database/key_value/transaction_ids.rs similarity index 77% rename from src/service/transaction_ids.rs rename to src/database/key_value/transaction_ids.rs index ed0970d1..81c1197d 100644 --- a/src/service/transaction_ids.rs +++ b/src/database/key_value/transaction_ids.rs @@ -1,15 +1,4 @@ -use std::sync::Arc; - -use crate::Result; -use ruma::{DeviceId, TransactionId, UserId}; - -use super::abstraction::Tree; - -pub struct TransactionIds { - pub(super) userdevicetxnid_response: Arc, // Response can be empty (/sendToDevice) or the event id (/send) -} - -impl TransactionIds { +impl service::pusher::Data for KeyValueDatabase { pub fn add_txnid( &self, user_id: &UserId, diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 7c15f1d8..5ef058f3 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -1,49 +1,10 @@ -use crate::{utils, Error, Result}; -use ruma::{ - api::client::{device::Device, error::ErrorKind, filter::IncomingFilterDefinition}, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{AnyToDeviceEvent, StateEventType}, - serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, MxcUri, RoomAliasId, - UInt, UserId, -}; -use std::{collections::BTreeMap, mem, sync::Arc}; -use tracing::warn; - -use super::abstraction::Tree; - -pub struct Users { - pub(super) userid_password: Arc, - pub(super) userid_displayname: Arc, - pub(super) userid_avatarurl: Arc, - pub(super) userid_blurhash: Arc, - pub(super) userdeviceid_token: Arc, - pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists - pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 - pub(super) token_userdeviceid: Arc, - - pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId - pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count - pub(super) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count - pub(super) keyid_key: Arc, // KeyId = UserId + KeyId (depends on key type) - pub(super) userid_masterkeyid: Arc, - pub(super) userid_selfsigningkeyid: Arc, - pub(super) userid_usersigningkeyid: Arc, - - pub(super) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId - - pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count -} - -impl Users { +impl service::users::Data for KeyValueDatabase { /// Check if a user has an account on this homeserver. - #[tracing::instrument(skip(self, user_id))] pub fn exists(&self, user_id: &UserId) -> Result { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } /// Check if account is deactivated - #[tracing::instrument(skip(self, user_id))] pub fn is_deactivated(&self, user_id: &UserId) -> Result { Ok(self .userid_password @@ -56,7 +17,6 @@ impl Users { } /// Check if a user is an admin - #[tracing::instrument(skip(self, user_id, rooms, globals))] pub fn is_admin( &self, user_id: &UserId, @@ -71,20 +31,17 @@ impl Users { } /// Create a new user account on this homeserver. - #[tracing::instrument(skip(self, user_id, password))] pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { self.set_password(user_id, password)?; Ok(()) } /// Returns the number of users registered on this server. - #[tracing::instrument(skip(self))] pub fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } /// Find out which user an access token belongs to. - #[tracing::instrument(skip(self, token))] pub fn find_from_token(&self, token: &str) -> Result, String)>> { self.token_userdeviceid .get(token.as_bytes())? @@ -112,7 +69,6 @@ impl Users { } /// Returns an iterator over all users on this homeserver. - #[tracing::instrument(skip(self))] pub fn iter(&self) -> impl Iterator>> + '_ { self.userid_password.iter().map(|(bytes, _)| { UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { @@ -125,7 +81,6 @@ impl Users { /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is greater then zero. - #[tracing::instrument(skip(self))] pub fn list_local_users(&self) -> Result> { let users: Vec = self .userid_password @@ -139,7 +94,6 @@ impl Users { /// username could be successfully parsed. /// If utils::string_from_bytes(...) returns an error that username will be skipped /// and the error will be logged. - #[tracing::instrument(skip(self))] fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option { // A valid password is not empty if password.is_empty() { @@ -159,7 +113,6 @@ impl Users { } /// Returns the password hash for the given user. - #[tracing::instrument(skip(self, user_id))] pub fn password_hash(&self, user_id: &UserId) -> Result> { self.userid_password .get(user_id.as_bytes())? @@ -171,7 +124,6 @@ impl Users { } /// Hash and set the user's password to the Argon2 hash - #[tracing::instrument(skip(self, user_id, password))] pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { if let Some(password) = password { if let Ok(hash) = utils::calculate_hash(password) { @@ -191,7 +143,6 @@ impl Users { } /// Returns the displayname of a user on this homeserver. - #[tracing::instrument(skip(self, user_id))] pub fn displayname(&self, user_id: &UserId) -> Result> { self.userid_displayname .get(user_id.as_bytes())? @@ -203,7 +154,6 @@ impl Users { } /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. - #[tracing::instrument(skip(self, user_id, displayname))] pub fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { if let Some(displayname) = displayname { self.userid_displayname @@ -216,7 +166,6 @@ impl Users { } /// Get the avatar_url of a user. - #[tracing::instrument(skip(self, user_id))] pub fn avatar_url(&self, user_id: &UserId) -> Result>> { self.userid_avatarurl .get(user_id.as_bytes())? @@ -230,7 +179,6 @@ impl Users { } /// Sets a new avatar_url or removes it if avatar_url is None. - #[tracing::instrument(skip(self, user_id, avatar_url))] pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option>) -> Result<()> { if let Some(avatar_url) = avatar_url { self.userid_avatarurl @@ -243,7 +191,6 @@ impl Users { } /// Get the blurhash of a user. - #[tracing::instrument(skip(self, user_id))] pub fn blurhash(&self, user_id: &UserId) -> Result> { self.userid_blurhash .get(user_id.as_bytes())? @@ -257,7 +204,6 @@ impl Users { } /// Sets a new avatar_url or removes it if avatar_url is None. - #[tracing::instrument(skip(self, user_id, blurhash))] pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { if let Some(blurhash) = blurhash { self.userid_blurhash @@ -270,7 +216,6 @@ impl Users { } /// Adds a new device to a user. - #[tracing::instrument(skip(self, user_id, device_id, token, initial_device_display_name))] pub fn create_device( &self, user_id: &UserId, @@ -305,7 +250,6 @@ impl Users { } /// Removes a device from a user. - #[tracing::instrument(skip(self, user_id, device_id))] pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); @@ -336,7 +280,6 @@ impl Users { } /// Returns an iterator over all device ids of this user. - #[tracing::instrument(skip(self, user_id))] pub fn all_device_ids<'a>( &'a self, user_id: &UserId, @@ -359,7 +302,6 @@ impl Users { } /// Replaces the access token of one device. - #[tracing::instrument(skip(self, user_id, device_id, token))] pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); @@ -383,14 +325,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip( - self, - user_id, - device_id, - one_time_key_key, - one_time_key_value, - globals - ))] pub fn add_one_time_key( &self, user_id: &UserId, @@ -427,7 +361,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_id))] pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { self.userid_lastonetimekeyupdate .get(user_id.as_bytes())? @@ -439,7 +372,6 @@ impl Users { .unwrap_or(Ok(0)) } - #[tracing::instrument(skip(self, user_id, device_id, key_algorithm, globals))] pub fn take_one_time_key( &self, user_id: &UserId, @@ -479,7 +411,6 @@ impl Users { .transpose() } - #[tracing::instrument(skip(self, user_id, device_id))] pub fn count_one_time_keys( &self, user_id: &UserId, @@ -512,7 +443,6 @@ impl Users { Ok(counts) } - #[tracing::instrument(skip(self, user_id, device_id, device_keys, rooms, globals))] pub fn add_device_keys( &self, user_id: &UserId, @@ -535,14 +465,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip( - self, - master_key, - self_signing_key, - user_signing_key, - rooms, - globals - ))] pub fn add_cross_signing_keys( &self, user_id: &UserId, @@ -658,7 +580,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, target_id, key_id, signature, sender_id, rooms, globals))] pub fn sign_key( &self, target_id: &UserId, @@ -703,7 +624,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_or_room_id, from, to))] pub fn keys_changed<'a>( &'a self, user_or_room_id: &str, @@ -742,7 +662,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id, rooms, globals))] pub fn mark_device_key_update( &self, user_id: &UserId, @@ -774,7 +693,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_device_keys( &self, user_id: &UserId, @@ -791,7 +709,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id, allowed_signatures))] pub fn get_master_key bool>( &self, user_id: &UserId, @@ -813,7 +730,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id, allowed_signatures))] pub fn get_self_signing_key bool>( &self, user_id: &UserId, @@ -835,7 +751,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id))] pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { self.userid_usersigningkeyid .get(user_id.as_bytes())? @@ -848,15 +763,6 @@ impl Users { }) } - #[tracing::instrument(skip( - self, - sender, - target_user_id, - target_device_id, - event_type, - content, - globals - ))] pub fn add_to_device_event( &self, sender: &UserId, @@ -884,7 +790,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_to_device_events( &self, user_id: &UserId, @@ -907,7 +812,6 @@ impl Users { Ok(events) } - #[tracing::instrument(skip(self, user_id, device_id, until))] pub fn remove_to_device_events( &self, user_id: &UserId, @@ -942,7 +846,6 @@ impl Users { Ok(()) } - #[tracing::instrument(skip(self, user_id, device_id, device))] pub fn update_device_metadata( &self, user_id: &UserId, @@ -968,7 +871,6 @@ impl Users { } /// Get device metadata. - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_device_metadata( &self, user_id: &UserId, @@ -987,7 +889,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id))] pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { self.userid_devicelistversion .get(user_id.as_bytes())? @@ -998,7 +899,6 @@ impl Users { }) } - #[tracing::instrument(skip(self, user_id))] pub fn all_devices_metadata<'a>( &'a self, user_id: &UserId, @@ -1014,25 +914,7 @@ impl Users { }) } - /// Deactivate account - #[tracing::instrument(skip(self, user_id))] - pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { - // Remove all associated devices - for device_id in self.all_device_ids(user_id) { - self.remove_device(user_id, &device_id?)?; - } - - // Set the password to "" to indicate a deactivated account. Hashes will never result in an - // empty string, so the user will not be able to log in again. Systems like changing the - // password without logging in should check if the account is deactivated. - self.userid_password.insert(user_id.as_bytes(), &[])?; - - // TODO: Unhook 3PID - Ok(()) - } - /// Creates a new sync filter. Returns the filter id. - #[tracing::instrument(skip(self))] pub fn create_filter( &self, user_id: &UserId, @@ -1052,7 +934,6 @@ impl Users { Ok(filter_id) } - #[tracing::instrument(skip(self))] pub fn get_filter( &self, user_id: &UserId, @@ -1072,30 +953,3 @@ impl Users { } } } - -/// Ensure that a user only sees signatures from themselves and the target user -fn clean_signatures bool>( - cross_signing_key: &mut serde_json::Value, - user_id: &UserId, - allowed_signatures: F, -) -> Result<(), Error> { - if let Some(signatures) = cross_signing_key - .get_mut("signatures") - .and_then(|v| v.as_object_mut()) - { - // Don't allocate for the full size of the current signatures, but require - // at most one resize if nothing is dropped - let new_capacity = signatures.len() / 2; - for (user, signature) in - mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) - { - let id = <&UserId>::try_from(user.as_str()) - .map_err(|_| Error::bad_database("Invalid user ID in database."))?; - if id == user_id || allowed_signatures(id) { - signatures.insert(user, signature); - } - } - } - - Ok(()) -} diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs new file mode 100644 index 00000000..fe57451f --- /dev/null +++ b/src/service/appservice/data.rs @@ -0,0 +1,17 @@ +pub trait Data { + /// Registers an appservice and returns the ID to the caller + pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result; + + /// Remove an appservice registration + /// + /// # Arguments + /// + /// * `service_name` - the name you send to register the service previously + pub fn unregister_appservice(&self, service_name: &str) -> Result<()>; + + pub fn get_registration(&self, id: &str) -> Result>; + + pub fn iter_ids(&self) -> Result> + '_>; + + pub fn all(&self) -> Result>; +} diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs new file mode 100644 index 00000000..ec4ffc56 --- /dev/null +++ b/src/service/appservice/mod.rs @@ -0,0 +1,36 @@ +mod data; +pub use data::Data; + +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { + /// Registers an appservice and returns the ID to the caller + pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { + self.db.register_appservice(yaml) + } + + /// Remove an appservice registration + /// + /// # Arguments + /// + /// * `service_name` - the name you send to register the service previously + pub fn unregister_appservice(&self, service_name: &str) -> Result<()> { + self.db.unregister_appservice(service_name) + } + + pub fn get_registration(&self, id: &str) -> Result> { + self.db.get_registration(id) + } + + pub fn iter_ids(&self) -> Result> + '_> { + self.db.iter_ids() + } + + pub fn all(&self) -> Result> { + self.db.all() + } +} diff --git a/src/service/globals.rs b/src/service/globals.rs index 7e09128e..2b47e5b1 100644 --- a/src/service/globals.rs +++ b/src/service/globals.rs @@ -1,3 +1,8 @@ +mod data; +pub use data::Data; + +use crate::service::*; + use crate::{database::Config, server_server::FedDest, utils, Error, Result}; use ruma::{ api::{ @@ -32,10 +37,11 @@ type SyncHandle = ( Receiver>>, // rx ); -pub struct Globals { +pub struct Service { + db: D, + pub actual_destination_cache: Arc>, // actual_destination, host pub tls_name_override: Arc>, - pub(super) globals: Arc, pub config: Config, keypair: Arc, dns_resolver: TokioAsyncResolver, @@ -44,7 +50,6 @@ pub struct Globals { default_client: reqwest::Client, pub stable_room_versions: Vec, pub unstable_room_versions: Vec, - pub(super) server_signingkeys: Arc, pub bad_event_ratelimiter: Arc, RateLimitState>>>, pub bad_signature_ratelimiter: Arc, RateLimitState>>>, pub servername_ratelimiter: Arc, Arc>>>, @@ -87,7 +92,8 @@ impl Default for RotationHandler { } } -impl Globals { + +impl Service<_> { pub fn load( globals: Arc, server_signingkeys: Arc, diff --git a/src/service/pusher.rs b/src/service/pusher.rs deleted file mode 100644 index 6b906c24..00000000 --- a/src/service/pusher.rs +++ /dev/null @@ -1,348 +0,0 @@ -use crate::{Database, Error, PduEvent, Result}; -use bytes::BytesMut; -use ruma::{ - api::{ - client::push::{get_pushers, set_pusher, PusherKind}, - push_gateway::send_event_notification::{ - self, - v1::{Device, Notification, NotificationCounts, NotificationPriority}, - }, - IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, - }, - events::{ - room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent}, - AnySyncRoomEvent, RoomEventType, StateEventType, - }, - push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, - serde::Raw, - uint, RoomId, UInt, UserId, -}; -use tracing::{error, info, warn}; - -use std::{fmt::Debug, mem, sync::Arc}; - -use super::abstraction::Tree; - -pub struct PushData { - /// UserId + pushkey -> Pusher - pub(super) senderkey_pusher: Arc, -} - -impl PushData { - #[tracing::instrument(skip(self, sender, pusher))] - pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { - let mut key = sender.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(pusher.pushkey.as_bytes()); - - // There are 2 kinds of pushers but the spec says: null deletes the pusher. - if pusher.kind.is_none() { - return self - .senderkey_pusher - .remove(&key) - .map(|_| ()) - .map_err(Into::into); - } - - self.senderkey_pusher.insert( - &key, - &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), - )?; - - Ok(()) - } - - #[tracing::instrument(skip(self, senderkey))] - pub fn get_pusher(&self, senderkey: &[u8]) -> Result> { - self.senderkey_pusher - .get(senderkey)? - .map(|push| { - serde_json::from_slice(&*push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .transpose() - } - - #[tracing::instrument(skip(self, sender))] - pub fn get_pushers(&self, sender: &UserId) -> Result> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); - - self.senderkey_pusher - .scan_prefix(prefix) - .map(|(_, push)| { - serde_json::from_slice(&*push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .collect() - } - - #[tracing::instrument(skip(self, sender))] - pub fn get_pusher_senderkeys<'a>( - &'a self, - sender: &UserId, - ) -> impl Iterator> + 'a { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); - - self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) - } -} - -#[tracing::instrument(skip(globals, destination, request))] -pub async fn send_request( - globals: &crate::database::globals::Globals, - destination: &str, - request: T, -) -> Result -where - T: Debug, -{ - let destination = destination.replace("/_matrix/push/v1/notify", ""); - - let http_request = request - .try_into_http_request::( - &destination, - SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_0], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", destination, e); - Error::BadServerResponse("Invalid destination") - })? - .map(|body| body.freeze()); - - let reqwest_request = reqwest::Request::try_from(http_request) - .expect("all http requests are valid reqwest requests"); - - // TODO: we could keep this very short and let expo backoff do it's thing... - //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); - - let url = reqwest_request.url().clone(); - let response = globals.default_client().execute(reqwest_request).await; - - match response { - Ok(mut response) => { - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if status != 200 { - info!( - "Push gateway returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - crate::utils::string_from_bytes(&body) - ); - } - - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - response.map_err(|_| { - info!( - "Push gateway returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Push gateway returned bad response.") - }) - } - Err(e) => Err(e.into()), - } -} - -#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] -pub async fn send_push_notice( - user: &UserId, - unread: UInt, - pusher: &get_pushers::v3::Pusher, - ruleset: Ruleset, - pdu: &PduEvent, - db: &Database, -) -> Result<()> { - let mut notify = None; - let mut tweaks = Vec::new(); - - let power_levels: RoomPowerLevelsEventContent = db - .rooms - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); - - for action in get_actions( - user, - &ruleset, - &power_levels, - &pdu.to_sync_room_event(), - &pdu.room_id, - db, - )? { - let n = match action { - Action::DontNotify => false, - // TODO: Implement proper support for coalesce - Action::Notify | Action::Coalesce => true, - Action::SetTweak(tweak) => { - tweaks.push(tweak.clone()); - continue; - } - }; - - if notify.is_some() { - return Err(Error::bad_database( - r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, - )); - } - - notify = Some(n); - } - - if notify == Some(true) { - send_notice(unread, pusher, tweaks, pdu, db).await?; - } - // Else the event triggered no actions - - Ok(()) -} - -#[tracing::instrument(skip(user, ruleset, pdu, db))] -pub fn get_actions<'a>( - user: &UserId, - ruleset: &'a Ruleset, - power_levels: &RoomPowerLevelsEventContent, - pdu: &Raw, - room_id: &RoomId, - db: &Database, -) -> Result<&'a [Action]> { - let ctx = PushConditionRoomCtx { - room_id: room_id.to_owned(), - member_count: 10_u32.into(), // TODO: get member count efficiently - user_display_name: db - .users - .displayname(user)? - .unwrap_or_else(|| user.localpart().to_owned()), - users_power_levels: power_levels.users.clone(), - default_power_level: power_levels.users_default, - notification_power_levels: power_levels.notifications.clone(), - }; - - Ok(ruleset.get_actions(pdu, &ctx)) -} - -#[tracing::instrument(skip(unread, pusher, tweaks, event, db))] -async fn send_notice( - unread: UInt, - pusher: &get_pushers::v3::Pusher, - tweaks: Vec, - event: &PduEvent, - db: &Database, -) -> Result<()> { - // TODO: email - if pusher.kind == PusherKind::Email { - return Ok(()); - } - - // TODO: - // Two problems with this - // 1. if "event_id_only" is the only format kind it seems we should never add more info - // 2. can pusher/devices have conflicting formats - let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly); - let url = if let Some(url) = &pusher.data.url { - url - } else { - error!("Http Pusher must have URL specified."); - return Ok(()); - }; - - let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone()); - let mut data_minus_url = pusher.data.clone(); - // The url must be stripped off according to spec - data_minus_url.url = None; - device.data = data_minus_url; - - // Tweaks are only added if the format is NOT event_id_only - if !event_id_only { - device.tweaks = tweaks.clone(); - } - - let d = &[device]; - let mut notifi = Notification::new(d); - - notifi.prio = NotificationPriority::Low; - notifi.event_id = Some(&event.event_id); - notifi.room_id = Some(&event.room_id); - // TODO: missed calls - notifi.counts = NotificationCounts::new(unread, uint!(0)); - - if event.kind == RoomEventType::RoomEncrypted - || tweaks - .iter() - .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) - { - notifi.prio = NotificationPriority::High - } - - if event_id_only { - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; - } else { - notifi.sender = Some(&event.sender); - notifi.event_type = Some(&event.kind); - let content = serde_json::value::to_raw_value(&event.content).ok(); - notifi.content = content.as_deref(); - - if event.kind == RoomEventType::RoomMember { - notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); - } - - let user_name = db.users.displayname(&event.sender)?; - notifi.sender_display_name = user_name.as_deref(); - - let room_name = if let Some(room_name_pdu) = - db.rooms - .room_state_get(&event.room_id, &StateEventType::RoomName, "")? - { - serde_json::from_str::(room_name_pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid room name event in database."))? - .name - } else { - None - }; - - notifi.room_name = room_name.as_deref(); - - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; - } - - // TODO: email - - Ok(()) -} diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs index 6b906c24..468ad8b4 100644 --- a/src/service/pusher/data.rs +++ b/src/service/pusher/data.rs @@ -1,348 +1,12 @@ -use crate::{Database, Error, PduEvent, Result}; -use bytes::BytesMut; -use ruma::{ - api::{ - client::push::{get_pushers, set_pusher, PusherKind}, - push_gateway::send_event_notification::{ - self, - v1::{Device, Notification, NotificationCounts, NotificationPriority}, - }, - IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, - }, - events::{ - room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent}, - AnySyncRoomEvent, RoomEventType, StateEventType, - }, - push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, - serde::Raw, - uint, RoomId, UInt, UserId, -}; -use tracing::{error, info, warn}; +pub trait Data { + fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()>; -use std::{fmt::Debug, mem, sync::Arc}; + pub fn get_pusher(&self, senderkey: &[u8]) -> Result>; -use super::abstraction::Tree; + pub fn get_pushers(&self, sender: &UserId) -> Result>; -pub struct PushData { - /// UserId + pushkey -> Pusher - pub(super) senderkey_pusher: Arc, -} - -impl PushData { - #[tracing::instrument(skip(self, sender, pusher))] - pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { - let mut key = sender.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(pusher.pushkey.as_bytes()); - - // There are 2 kinds of pushers but the spec says: null deletes the pusher. - if pusher.kind.is_none() { - return self - .senderkey_pusher - .remove(&key) - .map(|_| ()) - .map_err(Into::into); - } - - self.senderkey_pusher.insert( - &key, - &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), - )?; - - Ok(()) - } - - #[tracing::instrument(skip(self, senderkey))] - pub fn get_pusher(&self, senderkey: &[u8]) -> Result> { - self.senderkey_pusher - .get(senderkey)? - .map(|push| { - serde_json::from_slice(&*push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .transpose() - } - - #[tracing::instrument(skip(self, sender))] - pub fn get_pushers(&self, sender: &UserId) -> Result> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); - - self.senderkey_pusher - .scan_prefix(prefix) - .map(|(_, push)| { - serde_json::from_slice(&*push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .collect() - } - - #[tracing::instrument(skip(self, sender))] pub fn get_pusher_senderkeys<'a>( &'a self, sender: &UserId, - ) -> impl Iterator> + 'a { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); - - self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) - } -} - -#[tracing::instrument(skip(globals, destination, request))] -pub async fn send_request( - globals: &crate::database::globals::Globals, - destination: &str, - request: T, -) -> Result -where - T: Debug, -{ - let destination = destination.replace("/_matrix/push/v1/notify", ""); - - let http_request = request - .try_into_http_request::( - &destination, - SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_0], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", destination, e); - Error::BadServerResponse("Invalid destination") - })? - .map(|body| body.freeze()); - - let reqwest_request = reqwest::Request::try_from(http_request) - .expect("all http requests are valid reqwest requests"); - - // TODO: we could keep this very short and let expo backoff do it's thing... - //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); - - let url = reqwest_request.url().clone(); - let response = globals.default_client().execute(reqwest_request).await; - - match response { - Ok(mut response) => { - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if status != 200 { - info!( - "Push gateway returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - crate::utils::string_from_bytes(&body) - ); - } - - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - response.map_err(|_| { - info!( - "Push gateway returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Push gateway returned bad response.") - }) - } - Err(e) => Err(e.into()), - } -} - -#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] -pub async fn send_push_notice( - user: &UserId, - unread: UInt, - pusher: &get_pushers::v3::Pusher, - ruleset: Ruleset, - pdu: &PduEvent, - db: &Database, -) -> Result<()> { - let mut notify = None; - let mut tweaks = Vec::new(); - - let power_levels: RoomPowerLevelsEventContent = db - .rooms - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); - - for action in get_actions( - user, - &ruleset, - &power_levels, - &pdu.to_sync_room_event(), - &pdu.room_id, - db, - )? { - let n = match action { - Action::DontNotify => false, - // TODO: Implement proper support for coalesce - Action::Notify | Action::Coalesce => true, - Action::SetTweak(tweak) => { - tweaks.push(tweak.clone()); - continue; - } - }; - - if notify.is_some() { - return Err(Error::bad_database( - r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, - )); - } - - notify = Some(n); - } - - if notify == Some(true) { - send_notice(unread, pusher, tweaks, pdu, db).await?; - } - // Else the event triggered no actions - - Ok(()) -} - -#[tracing::instrument(skip(user, ruleset, pdu, db))] -pub fn get_actions<'a>( - user: &UserId, - ruleset: &'a Ruleset, - power_levels: &RoomPowerLevelsEventContent, - pdu: &Raw, - room_id: &RoomId, - db: &Database, -) -> Result<&'a [Action]> { - let ctx = PushConditionRoomCtx { - room_id: room_id.to_owned(), - member_count: 10_u32.into(), // TODO: get member count efficiently - user_display_name: db - .users - .displayname(user)? - .unwrap_or_else(|| user.localpart().to_owned()), - users_power_levels: power_levels.users.clone(), - default_power_level: power_levels.users_default, - notification_power_levels: power_levels.notifications.clone(), - }; - - Ok(ruleset.get_actions(pdu, &ctx)) -} - -#[tracing::instrument(skip(unread, pusher, tweaks, event, db))] -async fn send_notice( - unread: UInt, - pusher: &get_pushers::v3::Pusher, - tweaks: Vec, - event: &PduEvent, - db: &Database, -) -> Result<()> { - // TODO: email - if pusher.kind == PusherKind::Email { - return Ok(()); - } - - // TODO: - // Two problems with this - // 1. if "event_id_only" is the only format kind it seems we should never add more info - // 2. can pusher/devices have conflicting formats - let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly); - let url = if let Some(url) = &pusher.data.url { - url - } else { - error!("Http Pusher must have URL specified."); - return Ok(()); - }; - - let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone()); - let mut data_minus_url = pusher.data.clone(); - // The url must be stripped off according to spec - data_minus_url.url = None; - device.data = data_minus_url; - - // Tweaks are only added if the format is NOT event_id_only - if !event_id_only { - device.tweaks = tweaks.clone(); - } - - let d = &[device]; - let mut notifi = Notification::new(d); - - notifi.prio = NotificationPriority::Low; - notifi.event_id = Some(&event.event_id); - notifi.room_id = Some(&event.room_id); - // TODO: missed calls - notifi.counts = NotificationCounts::new(unread, uint!(0)); - - if event.kind == RoomEventType::RoomEncrypted - || tweaks - .iter() - .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) - { - notifi.prio = NotificationPriority::High - } - - if event_id_only { - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; - } else { - notifi.sender = Some(&event.sender); - notifi.event_type = Some(&event.kind); - let content = serde_json::value::to_raw_value(&event.content).ok(); - notifi.content = content.as_deref(); - - if event.kind == RoomEventType::RoomMember { - notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); - } - - let user_name = db.users.displayname(&event.sender)?; - notifi.sender_display_name = user_name.as_deref(); - - let room_name = if let Some(room_name_pdu) = - db.rooms - .room_state_get(&event.room_id, &StateEventType::RoomName, "")? - { - serde_json::from_str::(room_name_pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid room name event in database."))? - .name - } else { - None - }; - - notifi.room_name = room_name.as_deref(); - - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; - } - - // TODO: email - - Ok(()) + ) -> impl Iterator> + 'a; } diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 6b906c24..342763e8 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,348 +1,287 @@ -use crate::{Database, Error, PduEvent, Result}; -use bytes::BytesMut; -use ruma::{ - api::{ - client::push::{get_pushers, set_pusher, PusherKind}, - push_gateway::send_event_notification::{ - self, - v1::{Device, Notification, NotificationCounts, NotificationPriority}, - }, - IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, - }, - events::{ - room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent}, - AnySyncRoomEvent, RoomEventType, StateEventType, - }, - push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, - serde::Raw, - uint, RoomId, UInt, UserId, -}; -use tracing::{error, info, warn}; +mod data; +pub use data::Data; -use std::{fmt::Debug, mem, sync::Arc}; +use crate::service::*; -use super::abstraction::Tree; - -pub struct PushData { - /// UserId + pushkey -> Pusher - pub(super) senderkey_pusher: Arc, +pub struct Service { + db: D, } -impl PushData { - #[tracing::instrument(skip(self, sender, pusher))] +impl Service<_> { pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { - let mut key = sender.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(pusher.pushkey.as_bytes()); - - // There are 2 kinds of pushers but the spec says: null deletes the pusher. - if pusher.kind.is_none() { - return self - .senderkey_pusher - .remove(&key) - .map(|_| ()) - .map_err(Into::into); - } - - self.senderkey_pusher.insert( - &key, - &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), - )?; - - Ok(()) + self.db.set_pusher(sender, pusher) } - #[tracing::instrument(skip(self, senderkey))] pub fn get_pusher(&self, senderkey: &[u8]) -> Result> { - self.senderkey_pusher - .get(senderkey)? - .map(|push| { - serde_json::from_slice(&*push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .transpose() + self.db.get_pusher(senderkey) } - #[tracing::instrument(skip(self, sender))] pub fn get_pushers(&self, sender: &UserId) -> Result> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); - - self.senderkey_pusher - .scan_prefix(prefix) - .map(|(_, push)| { - serde_json::from_slice(&*push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .collect() + self.db.get_pushers(sender) } - #[tracing::instrument(skip(self, sender))] pub fn get_pusher_senderkeys<'a>( &'a self, sender: &UserId, ) -> impl Iterator> + 'a { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); - - self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) - } -} - -#[tracing::instrument(skip(globals, destination, request))] -pub async fn send_request( - globals: &crate::database::globals::Globals, - destination: &str, - request: T, -) -> Result -where - T: Debug, -{ - let destination = destination.replace("/_matrix/push/v1/notify", ""); - - let http_request = request - .try_into_http_request::( - &destination, - SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_0], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", destination, e); - Error::BadServerResponse("Invalid destination") - })? - .map(|body| body.freeze()); - - let reqwest_request = reqwest::Request::try_from(http_request) - .expect("all http requests are valid reqwest requests"); - - // TODO: we could keep this very short and let expo backoff do it's thing... - //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); - - let url = reqwest_request.url().clone(); - let response = globals.default_client().execute(reqwest_request).await; - - match response { - Ok(mut response) => { - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if status != 200 { - info!( - "Push gateway returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - crate::utils::string_from_bytes(&body) - ); - } - - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - response.map_err(|_| { - info!( - "Push gateway returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Push gateway returned bad response.") - }) - } - Err(e) => Err(e.into()), - } -} - -#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] -pub async fn send_push_notice( - user: &UserId, - unread: UInt, - pusher: &get_pushers::v3::Pusher, - ruleset: Ruleset, - pdu: &PduEvent, - db: &Database, -) -> Result<()> { - let mut notify = None; - let mut tweaks = Vec::new(); - - let power_levels: RoomPowerLevelsEventContent = db - .rooms - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); - - for action in get_actions( - user, - &ruleset, - &power_levels, - &pdu.to_sync_room_event(), - &pdu.room_id, - db, - )? { - let n = match action { - Action::DontNotify => false, - // TODO: Implement proper support for coalesce - Action::Notify | Action::Coalesce => true, - Action::SetTweak(tweak) => { - tweaks.push(tweak.clone()); - continue; - } - }; - - if notify.is_some() { - return Err(Error::bad_database( - r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, - )); - } - - notify = Some(n); + self.db.get_pusher_senderkeys(sender) } - if notify == Some(true) { - send_notice(unread, pusher, tweaks, pdu, db).await?; - } - // Else the event triggered no actions - - Ok(()) -} - -#[tracing::instrument(skip(user, ruleset, pdu, db))] -pub fn get_actions<'a>( - user: &UserId, - ruleset: &'a Ruleset, - power_levels: &RoomPowerLevelsEventContent, - pdu: &Raw, - room_id: &RoomId, - db: &Database, -) -> Result<&'a [Action]> { - let ctx = PushConditionRoomCtx { - room_id: room_id.to_owned(), - member_count: 10_u32.into(), // TODO: get member count efficiently - user_display_name: db - .users - .displayname(user)? - .unwrap_or_else(|| user.localpart().to_owned()), - users_power_levels: power_levels.users.clone(), - default_power_level: power_levels.users_default, - notification_power_levels: power_levels.notifications.clone(), - }; - - Ok(ruleset.get_actions(pdu, &ctx)) -} - -#[tracing::instrument(skip(unread, pusher, tweaks, event, db))] -async fn send_notice( - unread: UInt, - pusher: &get_pushers::v3::Pusher, - tweaks: Vec, - event: &PduEvent, - db: &Database, -) -> Result<()> { - // TODO: email - if pusher.kind == PusherKind::Email { - return Ok(()); - } - - // TODO: - // Two problems with this - // 1. if "event_id_only" is the only format kind it seems we should never add more info - // 2. can pusher/devices have conflicting formats - let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly); - let url = if let Some(url) = &pusher.data.url { - url - } else { - error!("Http Pusher must have URL specified."); - return Ok(()); - }; - - let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone()); - let mut data_minus_url = pusher.data.clone(); - // The url must be stripped off according to spec - data_minus_url.url = None; - device.data = data_minus_url; - - // Tweaks are only added if the format is NOT event_id_only - if !event_id_only { - device.tweaks = tweaks.clone(); - } - - let d = &[device]; - let mut notifi = Notification::new(d); - - notifi.prio = NotificationPriority::Low; - notifi.event_id = Some(&event.event_id); - notifi.room_id = Some(&event.room_id); - // TODO: missed calls - notifi.counts = NotificationCounts::new(unread, uint!(0)); - - if event.kind == RoomEventType::RoomEncrypted - || tweaks - .iter() - .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) + #[tracing::instrument(skip(globals, destination, request))] + pub async fn send_request( + globals: &crate::database::globals::Globals, + destination: &str, + request: T, + ) -> Result + where + T: Debug, { - notifi.prio = NotificationPriority::High + let destination = destination.replace("/_matrix/push/v1/notify", ""); + + let http_request = request + .try_into_http_request::( + &destination, + SendAccessToken::IfRequired(""), + &[MatrixVersion::V1_0], + ) + .map_err(|e| { + warn!("Failed to find destination {}: {}", destination, e); + Error::BadServerResponse("Invalid destination") + })? + .map(|body| body.freeze()); + + let reqwest_request = reqwest::Request::try_from(http_request) + .expect("all http requests are valid reqwest requests"); + + // TODO: we could keep this very short and let expo backoff do it's thing... + //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); + + let url = reqwest_request.url().clone(); + let response = globals.default_client().execute(reqwest_request).await; + + match response { + Ok(mut response) => { + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error {}", e); + Vec::new().into() + }); // TODO: handle timeout + + if status != 200 { + info!( + "Push gateway returned bad response {} {}\n{}\n{:?}", + destination, + status, + url, + crate::utils::string_from_bytes(&body) + ); + } + + let response = T::IncomingResponse::try_from_http_response( + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), + ); + response.map_err(|_| { + info!( + "Push gateway returned invalid response bytes {}\n{}", + destination, url + ); + Error::BadServerResponse("Push gateway returned bad response.") + }) + } + Err(e) => Err(e.into()), + } } - if event_id_only { - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; - } else { - notifi.sender = Some(&event.sender); - notifi.event_type = Some(&event.kind); - let content = serde_json::value::to_raw_value(&event.content).ok(); - notifi.content = content.as_deref(); + #[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] + pub async fn send_push_notice( + user: &UserId, + unread: UInt, + pusher: &get_pushers::v3::Pusher, + ruleset: Ruleset, + pdu: &PduEvent, + db: &Database, + ) -> Result<()> { + let mut notify = None; + let mut tweaks = Vec::new(); - if event.kind == RoomEventType::RoomMember { - notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); + let power_levels: RoomPowerLevelsEventContent = db + .rooms + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); + + for action in get_actions( + user, + &ruleset, + &power_levels, + &pdu.to_sync_room_event(), + &pdu.room_id, + db, + )? { + let n = match action { + Action::DontNotify => false, + // TODO: Implement proper support for coalesce + Action::Notify | Action::Coalesce => true, + Action::SetTweak(tweak) => { + tweaks.push(tweak.clone()); + continue; + } + }; + + if notify.is_some() { + return Err(Error::bad_database( + r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, + )); + } + + notify = Some(n); } - let user_name = db.users.displayname(&event.sender)?; - notifi.sender_display_name = user_name.as_deref(); + if notify == Some(true) { + send_notice(unread, pusher, tweaks, pdu, db).await?; + } + // Else the event triggered no actions - let room_name = if let Some(room_name_pdu) = - db.rooms - .room_state_get(&event.room_id, &StateEventType::RoomName, "")? - { - serde_json::from_str::(room_name_pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid room name event in database."))? - .name - } else { - None - }; - - notifi.room_name = room_name.as_deref(); - - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; + Ok(()) } - // TODO: email + #[tracing::instrument(skip(user, ruleset, pdu, db))] + pub fn get_actions<'a>( + user: &UserId, + ruleset: &'a Ruleset, + power_levels: &RoomPowerLevelsEventContent, + pdu: &Raw, + room_id: &RoomId, + db: &Database, + ) -> Result<&'a [Action]> { + let ctx = PushConditionRoomCtx { + room_id: room_id.to_owned(), + member_count: 10_u32.into(), // TODO: get member count efficiently + user_display_name: db + .users + .displayname(user)? + .unwrap_or_else(|| user.localpart().to_owned()), + users_power_levels: power_levels.users.clone(), + default_power_level: power_levels.users_default, + notification_power_levels: power_levels.notifications.clone(), + }; - Ok(()) + Ok(ruleset.get_actions(pdu, &ctx)) + } + + #[tracing::instrument(skip(unread, pusher, tweaks, event, db))] + async fn send_notice( + unread: UInt, + pusher: &get_pushers::v3::Pusher, + tweaks: Vec, + event: &PduEvent, + db: &Database, + ) -> Result<()> { + // TODO: email + if pusher.kind == PusherKind::Email { + return Ok(()); + } + + // TODO: + // Two problems with this + // 1. if "event_id_only" is the only format kind it seems we should never add more info + // 2. can pusher/devices have conflicting formats + let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly); + let url = if let Some(url) = &pusher.data.url { + url + } else { + error!("Http Pusher must have URL specified."); + return Ok(()); + }; + + let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone()); + let mut data_minus_url = pusher.data.clone(); + // The url must be stripped off according to spec + data_minus_url.url = None; + device.data = data_minus_url; + + // Tweaks are only added if the format is NOT event_id_only + if !event_id_only { + device.tweaks = tweaks.clone(); + } + + let d = &[device]; + let mut notifi = Notification::new(d); + + notifi.prio = NotificationPriority::Low; + notifi.event_id = Some(&event.event_id); + notifi.room_id = Some(&event.room_id); + // TODO: missed calls + notifi.counts = NotificationCounts::new(unread, uint!(0)); + + if event.kind == RoomEventType::RoomEncrypted + || tweaks + .iter() + .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) + { + notifi.prio = NotificationPriority::High + } + + if event_id_only { + send_request( + &db.globals, + url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; + } else { + notifi.sender = Some(&event.sender); + notifi.event_type = Some(&event.kind); + let content = serde_json::value::to_raw_value(&event.content).ok(); + notifi.content = content.as_deref(); + + if event.kind == RoomEventType::RoomMember { + notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); + } + + let user_name = db.users.displayname(&event.sender)?; + notifi.sender_display_name = user_name.as_deref(); + + let room_name = if let Some(room_name_pdu) = + db.rooms + .room_state_get(&event.room_id, &StateEventType::RoomName, "")? + { + serde_json::from_str::(room_name_pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid room name event in database."))? + .name + } else { + None + }; + + notifi.room_name = room_name.as_deref(); + + send_request( + &db.globals, + url, + send_event_notification::v1::Request::new(notifi), + ) + .await?; + } + + // TODO: email + + Ok(()) + } } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index c44d357c..a8e87b91 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,4 +1,13 @@ +mod data; +pub use data::Data; +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { pub fn get_or_create_shorteventid( &self, event_id: &EventId, @@ -222,4 +231,4 @@ } }) } - +} diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 5b423d2d..4e5c3796 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,816 +1,41 @@ - - /// Checks if a room exists. - #[tracing::instrument(skip(self))] - pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - // Look for PDUs in that room. - self.pduid_pdu - .iter_from(&prefix, false) - .filter(|(k, _)| k.starts_with(&prefix)) - .map(|(_, pdu)| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid first PDU in db.")) - .map(Arc::new) - }) - .next() - .transpose() - } - - #[tracing::instrument(skip(self))] - pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { - match self - .lasttimelinecount_cache - .lock() - .unwrap() - .entry(room_id.to_owned()) - { - hash_map::Entry::Vacant(v) => { - if let Some(last_count) = self - .pdus_until(&sender_user, &room_id, u64::MAX)? - .filter_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) - .map(|(pduid, _)| self.pdu_count(&pduid)) - .next() - { - Ok(*v.insert(last_count?)) - } else { - Ok(0) - } - } - hash_map::Entry::Occupied(o) => Ok(*o.get()), - } - } - - // TODO Is this the same as the function above? - #[tracing::instrument(skip(self))] - pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.pduid_pdu - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|b| self.pdu_count(&b.0)) - .transpose() - .map(|op| op.unwrap_or_default()) - } - - +pub trait Data { + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; /// Returns the `count` of this pdu's id. - pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pdu_id| self.pdu_count(&pdu_id)) - .transpose() - } + fn get_pdu_count(&self, event_id: &EventId) -> Result>; /// Returns the json of a pdu. - pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map_or_else( - || self.eventid_outlierpdu.get(event_id.as_bytes()), - |pduid| { - Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { - Error::bad_database("Invalid pduid in eventid_pduid.") - })?)) - }, - )? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() - } + pub fn get_pdu_json(&self, event_id: &EventId) -> Result>; /// Returns the json of a pdu. pub fn get_non_outlier_pdu_json( - &self, - event_id: &EventId, - ) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() - } /// Returns the pdu's id. - pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.eventid_pduid.get(event_id.as_bytes()) - } + pub fn get_pdu_id(&self, event_id: &EventId) -> Result>>; /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() - } + pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result>; /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_pdu(&self, event_id: &EventId) -> Result>> { - if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { - return Ok(Some(Arc::clone(p))); - } - - if let Some(pdu) = self - .eventid_pduid - .get(event_id.as_bytes())? - .map_or_else( - || self.eventid_outlierpdu.get(event_id.as_bytes()), - |pduid| { - Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { - Error::bad_database("Invalid pduid in eventid_pduid.") - })?)) - }, - )? - .map(|pdu| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db.")) - .map(Arc::new) - }) - .transpose()? - { - self.pdu_cache - .lock() - .unwrap() - .insert(event_id.to_owned(), Arc::clone(&pdu)); - Ok(Some(pdu)) - } else { - Ok(None) - } - } + pub fn get_pdu(&self, event_id: &EventId) -> Result>>; /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) - } + pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result>; /// Returns the pdu as a `BTreeMap`. - pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) - } + pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result>; /// Returns the `count` of this pdu's id. - pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { - utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) - } + pub fn pdu_count(&self, pdu_id: &[u8]) -> Result; /// Removes a pdu and creates a new one with the same id. - #[tracing::instrument(skip(self))] - fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { - if self.pduid_pdu.get(pdu_id)?.is_some() { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(pdu).expect("PduEvent::to_vec always works"), - )?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::NotFound, - "PDU does not exist.", - )) - } - } - - /// Creates a new persisted data unit and adds it to a room. - /// - /// By this point the incoming event should be fully authenticated, no auth happens - /// in `append_pdu`. - /// - /// Returns pdu id - #[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))] - pub fn append_pdu<'a>( - &self, - pdu: &PduEvent, - mut pdu_json: CanonicalJsonObject, - leaves: impl IntoIterator + Debug, - db: &Database, - ) -> Result> { - let shortroomid = self.get_shortroomid(&pdu.room_id)?.expect("room exists"); - - // Make unsigned fields correct. This is not properly documented in the spec, but state - // events need to have previous content in the unsigned field, so clients can easily - // interpret things like membership changes - if let Some(state_key) = &pdu.state_key { - if let CanonicalJsonValue::Object(unsigned) = pdu_json - .entry("unsigned".to_owned()) - .or_insert_with(|| CanonicalJsonValue::Object(Default::default())) - { - if let Some(shortstatehash) = self.pdu_shortstatehash(&pdu.event_id).unwrap() { - if let Some(prev_state) = self - .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) - .unwrap() - { - unsigned.insert( - "prev_content".to_owned(), - CanonicalJsonValue::Object( - utils::to_canonical_object(prev_state.content.clone()) - .expect("event is valid, we just created it"), - ), - ); - } - } - } else { - error!("Invalid unsigned type in pdu."); - } - } - - // We must keep track of all events that have been referenced. - self.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - self.replace_pdu_leaves(&pdu.room_id, leaves)?; - - let mutex_insert = Arc::clone( - db.globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(pdu.room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().unwrap(); - - let count1 = db.globals.next_count()?; - // Mark as read first so the sending client doesn't get a notification even if appending - // fails - self.edus - .private_read_set(&pdu.room_id, &pdu.sender, count1, &db.globals)?; - self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; - - let count2 = db.globals.next_count()?; - let mut pdu_id = shortroomid.to_be_bytes().to_vec(); - pdu_id.extend_from_slice(&count2.to_be_bytes()); - - // There's a brief moment of time here where the count is updated but the pdu does not - // exist. This could theoretically lead to dropped pdus, but it's extremely rare - // - // Update: We fixed this using insert_lock - - self.pduid_pdu.insert( - &pdu_id, - &serde_json::to_vec(&pdu_json).expect("CanonicalJsonObject is always a valid"), - )?; - self.lasttimelinecount_cache - .lock() - .unwrap() - .insert(pdu.room_id.clone(), count2); - - self.eventid_pduid - .insert(pdu.event_id.as_bytes(), &pdu_id)?; - self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; - - drop(insert_lock); - - // See if the event matches any known pushers - let power_levels: RoomPowerLevelsEventContent = db - .rooms - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); - - let sync_pdu = pdu.to_sync_room_event(); - - let mut notifies = Vec::new(); - let mut highlights = Vec::new(); - - for user in self.get_our_real_users(&pdu.room_id, db)?.iter() { - // Don't notify the user of their own events - if user == &pdu.sender { - continue; - } - - let rules_for_user = db - .account_data - .get( - None, - user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? - .map(|ev: PushRulesEvent| ev.content.global) - .unwrap_or_else(|| Ruleset::server_default(user)); - - let mut highlight = false; - let mut notify = false; - - for action in pusher::get_actions( - user, - &rules_for_user, - &power_levels, - &sync_pdu, - &pdu.room_id, - db, - )? { - match action { - Action::DontNotify => notify = false, - // TODO: Implement proper support for coalesce - Action::Notify | Action::Coalesce => notify = true, - Action::SetTweak(Tweak::Highlight(true)) => { - highlight = true; - } - _ => {} - }; - } - - let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(pdu.room_id.as_bytes()); - - if notify { - notifies.push(userroom_id.clone()); - } - - if highlight { - highlights.push(userroom_id); - } - - for senderkey in db.pusher.get_pusher_senderkeys(user) { - db.sending.send_push_pdu(&*pdu_id, senderkey)?; - } - } - - self.userroomid_notificationcount - .increment_batch(&mut notifies.into_iter())?; - self.userroomid_highlightcount - .increment_batch(&mut highlights.into_iter())?; - - match pdu.kind { - RoomEventType::RoomRedaction => { - if let Some(redact_id) = &pdu.redacts { - self.redact_pdu(redact_id, pdu)?; - } - } - RoomEventType::RoomMember => { - if let Some(state_key) = &pdu.state_key { - #[derive(Deserialize)] - struct ExtractMembership { - membership: MembershipState, - } - - // if the state_key fails - let target_user_id = UserId::parse(state_key.clone()) - .expect("This state_key was previously validated"); - - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - - let invite_state = match content.membership { - MembershipState::Invite => { - let state = self.calculate_invite_state(pdu)?; - Some(state) - } - _ => None, - }; - - // Update our membership info, we do this here incase a user is invited - // and immediately leaves we need the DB to record the invite event for auth - self.update_membership( - &pdu.room_id, - &target_user_id, - content.membership, - &pdu.sender, - invite_state, - db, - true, - )?; - } - } - RoomEventType::RoomMessage => { - #[derive(Deserialize)] - struct ExtractBody<'a> { - #[serde(borrow)] - body: Option>, - } - - let content = serde_json::from_str::>(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - - if let Some(body) = content.body { - DB.rooms.search.index_pdu(room_id, pdu_id, body)?; - - let admin_room = self.id_from_alias( - <&RoomAliasId>::try_from( - format!("#admins:{}", db.globals.server_name()).as_str(), - ) - .expect("#admins:server_name is a valid room alias"), - )?; - let server_user = format!("@conduit:{}", db.globals.server_name()); - - let to_conduit = body.starts_with(&format!("{}: ", server_user)); - - // This will evaluate to false if the emergency password is set up so that - // the administrator can execute commands as conduit - let from_conduit = - pdu.sender == server_user && db.globals.emergency_password().is_none(); - - if to_conduit && !from_conduit && admin_room.as_ref() == Some(&pdu.room_id) { - db.admin.process_message(body.to_string()); - } - } - } - _ => {} - } - - for appservice in db.appservice.all()? { - if self.appservice_in_room(room_id, &appservice, db)? { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - continue; - } - - // If the RoomMember event has a non-empty state_key, it is targeted at someone. - // If it is our appservice user, we send this PDU to it. - if pdu.kind == RoomEventType::RoomMember { - if let Some(state_key_uid) = &pdu - .state_key - .as_ref() - .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) - { - if let Some(appservice_uid) = appservice - .1 - .get("sender_localpart") - .and_then(|string| string.as_str()) - .and_then(|string| { - UserId::parse_with_server_name(string, db.globals.server_name()).ok() - }) - { - if state_key_uid == &appservice_uid { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - continue; - } - } - } - } - - if let Some(namespaces) = appservice.1.get("namespaces") { - let users = namespaces - .get("users") - .and_then(|users| users.as_sequence()) - .map_or_else(Vec::new, |users| { - users - .iter() - .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) - .collect::>() - }); - let aliases = namespaces - .get("aliases") - .and_then(|aliases| aliases.as_sequence()) - .map_or_else(Vec::new, |aliases| { - aliases - .iter() - .filter_map(|aliases| Regex::new(aliases.get("regex")?.as_str()?).ok()) - .collect::>() - }); - let rooms = namespaces - .get("rooms") - .and_then(|rooms| rooms.as_sequence()); - - let matching_users = |users: &Regex| { - users.is_match(pdu.sender.as_str()) - || pdu.kind == RoomEventType::RoomMember - && pdu - .state_key - .as_ref() - .map_or(false, |state_key| users.is_match(state_key)) - }; - let matching_aliases = |aliases: &Regex| { - self.room_aliases(room_id) - .filter_map(|r| r.ok()) - .any(|room_alias| aliases.is_match(room_alias.as_str())) - }; - - if aliases.iter().any(matching_aliases) - || rooms.map_or(false, |rooms| rooms.contains(&room_id.as_str().into())) - || users.iter().any(matching_users) - { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; - } - } - } - - - Ok(pdu_id) - } - - pub fn create_hash_and_sign_event( - &self, - pdu_builder: PduBuilder, - sender: &UserId, - room_id: &RoomId, - db: &Database, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> (PduEvent, CanonicalJsonObj) { - let PduBuilder { - event_type, - content, - unsigned, - state_key, - redacts, - } = pdu_builder; - - let prev_events: Vec<_> = db - .rooms - .get_pdu_leaves(room_id)? - .into_iter() - .take(20) - .collect(); - - let create_event = db - .rooms - .room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: Option = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()?; - - // If there was no create event yet, assume we are creating a room with the default - // version right now - let room_version_id = create_event_content - .map_or(db.globals.default_room_version(), |create_event| { - create_event.room_version - }); - let room_version = - RoomVersion::new(&room_version_id).expect("room version is supported"); - - let auth_events = - self.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; - - // Our depth is the maximum depth of prev_events + 1 - let depth = prev_events - .iter() - .filter_map(|event_id| Some(db.rooms.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) - + uint!(1); - - let mut unsigned = unsigned.unwrap_or_default(); - - if let Some(state_key) = &state_key { - if let Some(prev_pdu) = - self.room_state_get(room_id, &event_type.to_string().into(), state_key)? - { - unsigned.insert( - "prev_content".to_owned(), - serde_json::from_str(prev_pdu.content.get()).expect("string is valid json"), - ); - unsigned.insert( - "prev_sender".to_owned(), - serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), - ); - } - } - - let pdu = PduEvent { - event_id: ruma::event_id!("$thiswillbefilledinlater").into(), - room_id: room_id.to_owned(), - sender: sender_user.to_owned(), - origin_server_ts: utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"), - kind: event_type, - content, - state_key, - prev_events, - depth, - auth_events: auth_events - .iter() - .map(|(_, pdu)| pdu.event_id.clone()) - .collect(), - redacts, - unsigned: if unsigned.is_empty() { - None - } else { - Some(to_raw_value(&unsigned).expect("to_raw_value always works")) - }, - hashes: EventHash { - sha256: "aaa".to_owned(), - }, - signatures: None, - }; - - let auth_check = state_res::auth_check( - &room_version, - &pdu, - None::, // TODO: third_party_invite - |k, s| auth_events.get(&(k.clone(), s.to_owned())), - ) - .map_err(|e| { - error!("{:?}", e); - Error::bad_database("Auth check failed.") - })?; - - if !auth_check { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Event is not authorized.", - )); - } - - // Hash and sign - let mut pdu_json = - utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); - - pdu_json.remove("event_id"); - - // Add origin because synapse likes that (and it's required in the spec) - pdu_json.insert( - "origin".to_owned(), - to_canonical_value(db.globals.server_name()) - .expect("server name is a valid CanonicalJsonValue"), - ); - - match ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut pdu_json, - &room_version_id, - ) { - Ok(_) => {} - Err(e) => { - return match e { - ruma::signatures::Error::PduSize => Err(Error::BadRequest( - ErrorKind::TooLarge, - "Message is too long", - )), - _ => Err(Error::BadRequest( - ErrorKind::Unknown, - "Signing event failed", - )), - } - } - } - - // Generate event id - pdu.event_id = EventId::parse_arc(format!( - "${}", - ruma::signatures::reference_hash(&pdu_json, &room_version_id) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); - - pdu_json.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(pdu.event_id.as_str().to_owned()), - ); - - // Generate short event id - let _shorteventid = self.get_or_create_shorteventid(&pdu.event_id, &db.globals)?; - } - - /// Creates a new persisted data unit and adds it to a room. This function takes a - /// roomid_mutex_state, meaning that only this function is able to mutate the room state. - #[tracing::instrument(skip(self, db, _mutex_lock))] - pub fn build_and_append_pdu( - &self, - pdu_builder: PduBuilder, - sender: &UserId, - room_id: &RoomId, - db: &Database, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result> { - - let (pdu, pdu_json) = create_hash_and_sign_event()?; - - - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. - let statehashid = self.append_to_state(&pdu, &db.globals)?; - - let pdu_id = self.append_pdu( - &pdu, - pdu_json, - // Since this PDU references all pdu_leaves we can update the leaves - // of the room - iter::once(&*pdu.event_id), - db, - )?; - - // We set the room state after inserting the pdu, so that we never have a moment in time - // where events in the current room state do not exist - self.set_room_state(room_id, statehashid)?; - - let mut servers: HashSet> = - self.room_servers(room_id).filter_map(|r| r.ok()).collect(); - - // In case we are kicking or banning a user, we need to inform their server of the change - if pdu.kind == RoomEventType::RoomMember { - if let Some(state_key_uid) = &pdu - .state_key - .as_ref() - .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) - { - servers.insert(Box::from(state_key_uid.server_name())); - } - } - - // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above - servers.remove(db.globals.server_name()); - - db.sending.send_pdu(servers.into_iter(), &pdu_id)?; - - Ok(pdu.event_id) - } - - /// Append the incoming event setting the state snapshot to the state from the - /// server that sent the event. - #[tracing::instrument(skip_all)] - fn append_incoming_pdu<'a>( - db: &Database, - pdu: &PduEvent, - pdu_json: CanonicalJsonObject, - new_room_leaves: impl IntoIterator + Clone + Debug, - state_ids_compressed: HashSet, - soft_fail: bool, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result>> { - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. - db.rooms.set_event_state( - &pdu.event_id, - &pdu.room_id, - state_ids_compressed, - &db.globals, - )?; - - if soft_fail { - db.rooms - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - db.rooms.replace_pdu_leaves(&pdu.room_id, new_room_leaves)?; - return Ok(None); - } - - let pdu_id = db.rooms.append_pdu(pdu, pdu_json, new_room_leaves, db)?; - - Ok(Some(pdu_id)) - } - - /// Returns an iterator over all PDUs in a room. - #[tracing::instrument(skip(self))] - pub fn all_pdus<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result, PduEvent)>> + 'a> { - self.pdus_since(user_id, room_id, 0) - } + fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()>; /// Returns an iterator over all events in a room that happened after the event with id `since` /// in chronological order. @@ -820,32 +45,7 @@ user_id: &UserId, room_id: &RoomId, since: u64, - ) -> Result, PduEvent)>> + 'a> { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - // Skip the first pdu if it's exactly at since, because we sent that last time - let mut first_pdu_id = prefix.clone(); - first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(&first_pdu_id, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) - } + ) -> Result, PduEvent)>> + 'a>; /// Returns an iterator over all events and their tokens in a room that happened before the /// event with id `until` in reverse-chronological order. @@ -855,83 +55,12 @@ user_id: &UserId, room_id: &RoomId, until: u64, - ) -> Result, PduEvent)>> + 'a> { - // Create the first part of the full pdu id - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); + ) -> Result, PduEvent)>> + 'a>; - let mut current = prefix.clone(); - current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` - - let current: &[u8] = ¤t; - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(current, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) - } - - /// Returns an iterator over all events and their token in a room that happened after the event - /// with id `from` in chronological order. - #[tracing::instrument(skip(self))] pub fn pdus_after<'a>( &'a self, user_id: &UserId, room_id: &RoomId, from: u64, - ) -> Result, PduEvent)>> + 'a> { - // Create the first part of the full pdu id - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut current = prefix.clone(); - current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event - - let current: &[u8] = ¤t; - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(current, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) - } - - /// Replace a PDU with the redacted form. - #[tracing::instrument(skip(self, reason))] - pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { - if let Some(pdu_id) = self.get_pdu_id(event_id)? { - let mut pdu = self - .get_pdu_from_id(&pdu_id)? - .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; - pdu.redact(reason)?; - self.replace_pdu(&pdu_id, &pdu)?; - } - // If event does not exist, just noop - Ok(()) - } - + ) -> Result, PduEvent)>> + 'a>; +} diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 5b423d2d..c6393c68 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,4 +1,14 @@ +mod data; +pub use data::Data; +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { + /* /// Checks if a room exists. #[tracing::instrument(skip(self))] pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { @@ -20,38 +30,15 @@ .next() .transpose() } + */ #[tracing::instrument(skip(self))] pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { - match self - .lasttimelinecount_cache - .lock() - .unwrap() - .entry(room_id.to_owned()) - { - hash_map::Entry::Vacant(v) => { - if let Some(last_count) = self - .pdus_until(&sender_user, &room_id, u64::MAX)? - .filter_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) - .map(|(pduid, _)| self.pdu_count(&pduid)) - .next() - { - Ok(*v.insert(last_count?)) - } else { - Ok(0) - } - } - hash_map::Entry::Occupied(o) => Ok(*o.get()), - } + self.db.last_timeline_count(sender_user: &UserId, room_id: &RoomId) } // TODO Is this the same as the function above? + /* #[tracing::instrument(skip(self))] pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { let prefix = self @@ -71,33 +58,16 @@ .transpose() .map(|op| op.unwrap_or_default()) } - - + */ /// Returns the `count` of this pdu's id. pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pdu_id| self.pdu_count(&pdu_id)) - .transpose() + self.db.get_pdu_count(event_id) } /// Returns the json of a pdu. pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map_or_else( - || self.eventid_outlierpdu.get(event_id.as_bytes()), - |pduid| { - Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { - Error::bad_database("Invalid pduid in eventid_pduid.") - })?)) - }, - )? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() + self.db.get_pdu_json(event_id) } /// Returns the json of a pdu. @@ -105,122 +75,49 @@ &self, event_id: &EventId, ) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() + self.db.get_non_outlier_pdu(event_id) } /// Returns the pdu's id. pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.eventid_pduid.get(event_id.as_bytes()) + self.db.get_pdu_id(event_id) } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() + self.db.get_non_outlier_pdu(event_id) } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. pub fn get_pdu(&self, event_id: &EventId) -> Result>> { - if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { - return Ok(Some(Arc::clone(p))); - } - - if let Some(pdu) = self - .eventid_pduid - .get(event_id.as_bytes())? - .map_or_else( - || self.eventid_outlierpdu.get(event_id.as_bytes()), - |pduid| { - Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| { - Error::bad_database("Invalid pduid in eventid_pduid.") - })?)) - }, - )? - .map(|pdu| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db.")) - .map(Arc::new) - }) - .transpose()? - { - self.pdu_cache - .lock() - .unwrap() - .insert(event_id.to_owned(), Arc::clone(&pdu)); - Ok(Some(pdu)) - } else { - Ok(None) - } + self.db.get_pdu(event_id) } /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + self.db.get_pdu_from_id(pdu_id) } /// Returns the pdu as a `BTreeMap`. pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) + self.db.get_pdu_json_from_id(pdu_id) } /// Returns the `count` of this pdu's id. pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { - utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) + self.db.pdu_count(pdu_id) } /// Removes a pdu and creates a new one with the same id. #[tracing::instrument(skip(self))] fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { - if self.pduid_pdu.get(pdu_id)?.is_some() { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(pdu).expect("PduEvent::to_vec always works"), - )?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::NotFound, - "PDU does not exist.", - )) - } + self.db.pdu_count(pdu_id, pdu: &PduEvent) } /// Creates a new persisted data unit and adds it to a room. @@ -803,7 +700,6 @@ } /// Returns an iterator over all PDUs in a room. - #[tracing::instrument(skip(self))] pub fn all_pdus<'a>( &'a self, user_id: &UserId, @@ -814,37 +710,13 @@ /// Returns an iterator over all events in a room that happened after the event with id `since` /// in chronological order. - #[tracing::instrument(skip(self))] pub fn pdus_since<'a>( &'a self, user_id: &UserId, room_id: &RoomId, since: u64, ) -> Result, PduEvent)>> + 'a> { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - // Skip the first pdu if it's exactly at since, because we sent that last time - let mut first_pdu_id = prefix.clone(); - first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(&first_pdu_id, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) + self.db.pdus_since(user_id, room_id, since) } /// Returns an iterator over all events and their tokens in a room that happened before the @@ -856,32 +728,7 @@ room_id: &RoomId, until: u64, ) -> Result, PduEvent)>> + 'a> { - // Create the first part of the full pdu id - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut current = prefix.clone(); - current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` - - let current: &[u8] = ¤t; - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(current, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) + self.db.pdus_until(user_id, room_id, until) } /// Returns an iterator over all events and their token in a room that happened after the event @@ -893,32 +740,7 @@ room_id: &RoomId, from: u64, ) -> Result, PduEvent)>> + 'a> { - // Create the first part of the full pdu id - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut current = prefix.clone(); - current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event - - let current: &[u8] = ¤t; - - let user_id = user_id.to_owned(); - - Ok(self - .pduid_pdu - .iter_from(current, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((pdu_id, pdu)) - })) + self.db.pdus_after(user_id, room_id, from) } /// Replace a PDU with the redacted form. diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs new file mode 100644 index 00000000..f1ff5f88 --- /dev/null +++ b/src/service/transaction_ids/data.rs @@ -0,0 +1,16 @@ +pub trait Data { + pub fn add_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + data: &[u8], + ) -> Result<()>; + + pub fn existing_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + ) -> Result>>; +} diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs new file mode 100644 index 00000000..d944847e --- /dev/null +++ b/src/service/transaction_ids/mod.rs @@ -0,0 +1,44 @@ +mod data; +pub use data::Data; + +use crate::service::*; + +pub struct Service { + db: D, +} + +impl Service<_> { + pub fn add_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + data: &[u8], + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default()); + key.push(0xff); + key.extend_from_slice(txn_id.as_bytes()); + + self.userdevicetxnid_response.insert(&key, data)?; + + Ok(()) + } + + pub fn existing_txnid( + &self, + user_id: &UserId, + device_id: Option<&DeviceId>, + txn_id: &TransactionId, + ) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(device_id.map(|d| d.as_bytes()).unwrap_or_default()); + key.push(0xff); + key.extend_from_slice(txn_id.as_bytes()); + + // If there's no entry, this is a new transaction + self.userdevicetxnid_response.get(&key) + } +} diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 7c15f1d8..d99d0328 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,396 +1,86 @@ -use crate::{utils, Error, Result}; -use ruma::{ - api::client::{device::Device, error::ErrorKind, filter::IncomingFilterDefinition}, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{AnyToDeviceEvent, StateEventType}, - serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, MxcUri, RoomAliasId, - UInt, UserId, -}; -use std::{collections::BTreeMap, mem, sync::Arc}; -use tracing::warn; - -use super::abstraction::Tree; - -pub struct Users { - pub(super) userid_password: Arc, - pub(super) userid_displayname: Arc, - pub(super) userid_avatarurl: Arc, - pub(super) userid_blurhash: Arc, - pub(super) userdeviceid_token: Arc, - pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists - pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 - pub(super) token_userdeviceid: Arc, - - pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId - pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count - pub(super) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count - pub(super) keyid_key: Arc, // KeyId = UserId + KeyId (depends on key type) - pub(super) userid_masterkeyid: Arc, - pub(super) userid_selfsigningkeyid: Arc, - pub(super) userid_usersigningkeyid: Arc, - - pub(super) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId - - pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count -} - -impl Users { +pub trait Data { /// Check if a user has an account on this homeserver. - #[tracing::instrument(skip(self, user_id))] - pub fn exists(&self, user_id: &UserId) -> Result { - Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) - } + pub fn exists(&self, user_id: &UserId) -> Result; /// Check if account is deactivated - #[tracing::instrument(skip(self, user_id))] - pub fn is_deactivated(&self, user_id: &UserId) -> Result { - Ok(self - .userid_password - .get(user_id.as_bytes())? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "User does not exist.", - ))? - .is_empty()) - } + pub fn is_deactivated(&self, user_id: &UserId) -> Result; /// Check if a user is an admin - #[tracing::instrument(skip(self, user_id, rooms, globals))] pub fn is_admin( &self, user_id: &UserId, rooms: &super::rooms::Rooms, globals: &super::globals::Globals, - ) -> Result { - let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", globals.server_name())) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; - let admin_room_id = rooms.id_from_alias(&admin_room_alias_id)?.unwrap(); - - rooms.is_joined(user_id, &admin_room_id) - } + ) -> Result; /// Create a new user account on this homeserver. - #[tracing::instrument(skip(self, user_id, password))] - pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.set_password(user_id, password)?; - Ok(()) - } + pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()>; /// Returns the number of users registered on this server. - #[tracing::instrument(skip(self))] - pub fn count(&self) -> Result { - Ok(self.userid_password.iter().count()) - } + pub fn count(&self) -> Result; /// Find out which user an access token belongs to. - #[tracing::instrument(skip(self, token))] - pub fn find_from_token(&self, token: &str) -> Result, String)>> { - self.token_userdeviceid - .get(token.as_bytes())? - .map_or(Ok(None), |bytes| { - let mut parts = bytes.split(|&b| b == 0xff); - let user_bytes = parts.next().ok_or_else(|| { - Error::bad_database("User ID in token_userdeviceid is invalid.") - })?; - let device_bytes = parts.next().ok_or_else(|| { - Error::bad_database("Device ID in token_userdeviceid is invalid.") - })?; - - Ok(Some(( - UserId::parse(utils::string_from_bytes(user_bytes).map_err(|_| { - Error::bad_database("User ID in token_userdeviceid is invalid unicode.") - })?) - .map_err(|_| { - Error::bad_database("User ID in token_userdeviceid is invalid.") - })?, - utils::string_from_bytes(device_bytes).map_err(|_| { - Error::bad_database("Device ID in token_userdeviceid is invalid.") - })?, - ))) - }) - } + pub fn find_from_token(&self, token: &str) -> Result, String)>>; /// Returns an iterator over all users on this homeserver. - #[tracing::instrument(skip(self))] - pub fn iter(&self) -> impl Iterator>> + '_ { - self.userid_password.iter().map(|(bytes, _)| { - UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in userid_password is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) - }) - } + pub fn iter(&self) -> impl Iterator>> + '_; /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is greater then zero. - #[tracing::instrument(skip(self))] - pub fn list_local_users(&self) -> Result> { - let users: Vec = self - .userid_password - .iter() - .filter_map(|(username, pw)| self.get_username_with_valid_password(&username, &pw)) - .collect(); - Ok(users) - } + pub fn list_local_users(&self) -> Result>; /// Will only return with Some(username) if the password was not empty and the /// username could be successfully parsed. /// If utils::string_from_bytes(...) returns an error that username will be skipped /// and the error will be logged. - #[tracing::instrument(skip(self))] - fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option { - // A valid password is not empty - if password.is_empty() { - None - } else { - match utils::string_from_bytes(username) { - Ok(u) => Some(u), - Err(e) => { - warn!( - "Failed to parse username while calling get_local_users(): {}", - e.to_string() - ); - None - } - } - } - } + fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option; /// Returns the password hash for the given user. - #[tracing::instrument(skip(self, user_id))] - pub fn password_hash(&self, user_id: &UserId) -> Result> { - self.userid_password - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Password hash in db is not valid string.") - })?)) - }) - } + pub fn password_hash(&self, user_id: &UserId) -> Result>; /// Hash and set the user's password to the Argon2 hash - #[tracing::instrument(skip(self, user_id, password))] - pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - if let Some(password) = password { - if let Ok(hash) = utils::calculate_hash(password) { - self.userid_password - .insert(user_id.as_bytes(), hash.as_bytes())?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Password does not meet the requirements.", - )) - } - } else { - self.userid_password.insert(user_id.as_bytes(), b"")?; - Ok(()) - } - } + pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()>; /// Returns the displayname of a user on this homeserver. - #[tracing::instrument(skip(self, user_id))] - pub fn displayname(&self, user_id: &UserId) -> Result> { - self.userid_displayname - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Displayname in db is invalid.") - })?)) - }) - } + pub fn displayname(&self, user_id: &UserId) -> Result>; /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. - #[tracing::instrument(skip(self, user_id, displayname))] - pub fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { - if let Some(displayname) = displayname { - self.userid_displayname - .insert(user_id.as_bytes(), displayname.as_bytes())?; - } else { - self.userid_displayname.remove(user_id.as_bytes())?; - } - - Ok(()) - } + pub fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()>; /// Get the avatar_url of a user. - #[tracing::instrument(skip(self, user_id))] - pub fn avatar_url(&self, user_id: &UserId) -> Result>> { - self.userid_avatarurl - .get(user_id.as_bytes())? - .map(|bytes| { - let s = utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; - s.try_into() - .map_err(|_| Error::bad_database("Avatar URL in db is invalid.")) - }) - .transpose() - } + pub fn avatar_url(&self, user_id: &UserId) -> Result>>; /// Sets a new avatar_url or removes it if avatar_url is None. - #[tracing::instrument(skip(self, user_id, avatar_url))] - pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option>) -> Result<()> { - if let Some(avatar_url) = avatar_url { - self.userid_avatarurl - .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; - } else { - self.userid_avatarurl.remove(user_id.as_bytes())?; - } - - Ok(()) - } + pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option>) -> Result<()>; /// Get the blurhash of a user. - #[tracing::instrument(skip(self, user_id))] - pub fn blurhash(&self, user_id: &UserId) -> Result> { - self.userid_blurhash - .get(user_id.as_bytes())? - .map(|bytes| { - let s = utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; - - Ok(s) - }) - .transpose() - } + pub fn blurhash(&self, user_id: &UserId) -> Result>; /// Sets a new avatar_url or removes it if avatar_url is None. - #[tracing::instrument(skip(self, user_id, blurhash))] - pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { - if let Some(blurhash) = blurhash { - self.userid_blurhash - .insert(user_id.as_bytes(), blurhash.as_bytes())?; - } else { - self.userid_blurhash.remove(user_id.as_bytes())?; - } - - Ok(()) - } + pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()>; /// Adds a new device to a user. - #[tracing::instrument(skip(self, user_id, device_id, token, initial_device_display_name))] pub fn create_device( &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, - ) -> Result<()> { - // This method should never be called for nonexistent users. - assert!(self.exists(user_id)?); - - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(&Device { - device_id: device_id.into(), - display_name: initial_device_display_name, - last_seen_ip: None, // TODO - last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), - }) - .expect("Device::to_string never fails."), - )?; - - self.set_token(user_id, device_id, token)?; - - Ok(()) - } + ) -> Result<()>; /// Removes a device from a user. - #[tracing::instrument(skip(self, user_id, device_id))] - pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Remove tokens - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.userdeviceid_token.remove(&userdeviceid)?; - self.token_userdeviceid.remove(&old_token)?; - } - - // Remove todevice events - let mut prefix = userdeviceid.clone(); - prefix.push(0xff); - - for (key, _) in self.todeviceid_events.scan_prefix(prefix) { - self.todeviceid_events.remove(&key)?; - } - - // TODO: Remove onetimekeys - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.remove(&userdeviceid)?; - - Ok(()) - } + pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; /// Returns an iterator over all device ids of this user. - #[tracing::instrument(skip(self, user_id))] pub fn all_device_ids<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator>> + 'a { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - // All devices have metadata - self.userdeviceid_metadata - .scan_prefix(prefix) - .map(|(bytes, _)| { - Ok(utils::string_from_bytes( - bytes - .rsplit(|&b| b == 0xff) - .next() - .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, - ) - .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? - .into()) - }) - } + ) -> impl Iterator>> + 'a; /// Replaces the access token of one device. - #[tracing::instrument(skip(self, user_id, device_id, token))] - pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); + pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; - // All devices have metadata - assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); - - // Remove old token - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.token_userdeviceid.remove(&old_token)?; - // It will be removed from userdeviceid_token by the insert later - } - - // Assign token to user device combination - self.userdeviceid_token - .insert(&userdeviceid, token.as_bytes())?; - self.token_userdeviceid - .insert(token.as_bytes(), &userdeviceid)?; - - Ok(()) - } - - #[tracing::instrument(skip( - self, - user_id, - device_id, - one_time_key_key, - one_time_key_value, - globals - ))] pub fn add_one_time_key( &self, user_id: &UserId, @@ -398,121 +88,24 @@ impl Users { one_time_key_key: &DeviceKeyId, one_time_key_value: &Raw, globals: &super::globals::Globals, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.as_bytes()); + ) -> Result<()>; - // All devices have metadata - // Only existing devices should be able to call this. - assert!(self.userdeviceid_metadata.get(&key)?.is_some()); + pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result; - key.push(0xff); - // TODO: Use DeviceKeyId::to_string when it's available (and update everything, - // because there are no wrapping quotation marks anymore) - key.extend_from_slice( - serde_json::to_string(one_time_key_key) - .expect("DeviceKeyId::to_string always works") - .as_bytes(), - ); - - self.onetimekeyid_onetimekeys.insert( - &key, - &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), - )?; - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; - - Ok(()) - } - - #[tracing::instrument(skip(self, user_id))] - pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.userid_lastonetimekeyupdate - .get(user_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") - }) - }) - .unwrap_or(Ok(0)) - } - - #[tracing::instrument(skip(self, user_id, device_id, key_algorithm, globals))] pub fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, globals: &super::globals::Globals, - ) -> Result, Raw)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - prefix.push(b'"'); // Annoying quotation mark - prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); - prefix.push(b':'); + ) -> Result, Raw)>>; - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; - - self.onetimekeyid_onetimekeys - .scan_prefix(prefix) - .next() - .map(|(key, value)| { - self.onetimekeyid_onetimekeys.remove(&key)?; - - Ok(( - serde_json::from_slice( - &*key - .rsplit(|&b| b == 0xff) - .next() - .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, - ) - .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, - serde_json::from_slice(&*value) - .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, - )) - }) - .transpose() - } - - #[tracing::instrument(skip(self, user_id, device_id))] pub fn count_one_time_keys( &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); + ) -> Result>; - let mut counts = BTreeMap::new(); - - for algorithm in - self.onetimekeyid_onetimekeys - .scan_prefix(userdeviceid) - .map(|(bytes, _)| { - Ok::<_, Error>( - serde_json::from_slice::>( - &*bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { - Error::bad_database("OneTimeKey ID in db is invalid.") - })?, - ) - .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? - .algorithm(), - ) - }) - { - *counts.entry(algorithm?).or_default() += UInt::from(1_u32); - } - - Ok(counts) - } - - #[tracing::instrument(skip(self, user_id, device_id, device_keys, rooms, globals))] pub fn add_device_keys( &self, user_id: &UserId, @@ -520,29 +113,8 @@ impl Users { device_keys: &Raw, rooms: &super::rooms::Rooms, globals: &super::globals::Globals, - ) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); + ) -> Result<()>; - self.keyid_key.insert( - &userdeviceid, - &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), - )?; - - self.mark_device_key_update(user_id, rooms, globals)?; - - Ok(()) - } - - #[tracing::instrument(skip( - self, - master_key, - self_signing_key, - user_signing_key, - rooms, - globals - ))] pub fn add_cross_signing_keys( &self, user_id: &UserId, @@ -551,114 +123,8 @@ impl Users { user_signing_key: &Option>, rooms: &super::rooms::Rooms, globals: &super::globals::Globals, - ) -> Result<()> { - // TODO: Check signatures + ) -> Result<()>; - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - // Master key - let mut master_key_ids = master_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))? - .keys - .into_values(); - - let master_key_id = master_key_ids.next().ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained no key.", - ))?; - - if master_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained more than one key.", - )); - } - - let mut master_key_key = prefix.clone(); - master_key_key.extend_from_slice(master_key_id.as_bytes()); - - self.keyid_key - .insert(&master_key_key, master_key.json().get().as_bytes())?; - - self.userid_masterkeyid - .insert(user_id.as_bytes(), &master_key_key)?; - - // Self-signing key - if let Some(self_signing_key) = self_signing_key { - let mut self_signing_key_ids = self_signing_key - .deserialize() - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key") - })? - .keys - .into_values(); - - let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained no key.", - ))?; - - if self_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained more than one key.", - )); - } - - let mut self_signing_key_key = prefix.clone(); - self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); - - self.keyid_key.insert( - &self_signing_key_key, - self_signing_key.json().get().as_bytes(), - )?; - - self.userid_selfsigningkeyid - .insert(user_id.as_bytes(), &self_signing_key_key)?; - } - - // User-signing key - if let Some(user_signing_key) = user_signing_key { - let mut user_signing_key_ids = user_signing_key - .deserialize() - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key") - })? - .keys - .into_values(); - - let user_signing_key_id = user_signing_key_ids.next().ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained no key.", - ))?; - - if user_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained more than one key.", - )); - } - - let mut user_signing_key_key = prefix; - user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); - - self.keyid_key.insert( - &user_signing_key_key, - user_signing_key.json().get().as_bytes(), - )?; - - self.userid_usersigningkeyid - .insert(user_id.as_bytes(), &user_signing_key_key)?; - } - - self.mark_device_key_update(user_id, rooms, globals)?; - - Ok(()) - } - - #[tracing::instrument(skip(self, target_id, key_id, signature, sender_id, rooms, globals))] pub fn sign_key( &self, target_id: &UserId, @@ -667,196 +133,42 @@ impl Users { sender_id: &UserId, rooms: &super::rooms::Rooms, globals: &super::globals::Globals, - ) -> Result<()> { - let mut key = target_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(key_id.as_bytes()); + ) -> Result<()>; - let mut cross_signing_key: serde_json::Value = - serde_json::from_slice(&self.keyid_key.get(&key)?.ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to sign nonexistent key.", - ))?) - .map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; - - let signatures = cross_signing_key - .get_mut("signatures") - .ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? - .as_object_mut() - .ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? - .entry(sender_id.to_owned()) - .or_insert_with(|| serde_json::Map::new().into()); - - signatures - .as_object_mut() - .ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))? - .insert(signature.0, signature.1.into()); - - self.keyid_key.insert( - &key, - &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), - )?; - - // TODO: Should we notify about this change? - self.mark_device_key_update(target_id, rooms, globals)?; - - Ok(()) - } - - #[tracing::instrument(skip(self, user_or_room_id, from, to))] pub fn keys_changed<'a>( &'a self, user_or_room_id: &str, from: u64, to: Option, - ) -> impl Iterator>> + 'a { - let mut prefix = user_or_room_id.as_bytes().to_vec(); - prefix.push(0xff); + ) -> impl Iterator>> + 'a; - let mut start = prefix.clone(); - start.extend_from_slice(&(from + 1).to_be_bytes()); - - let to = to.unwrap_or(u64::MAX); - - self.keychangeid_userid - .iter_from(&start, false) - .take_while(move |(k, _)| { - k.starts_with(&prefix) - && if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) { - if let Ok(c) = utils::u64_from_bytes(current) { - c <= to - } else { - warn!("BadDatabase: Could not parse keychangeid_userid bytes"); - false - } - } else { - warn!("BadDatabase: Could not parse keychangeid_userid"); - false - } - }) - .map(|(_, bytes)| { - UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid.")) - }) - } - - #[tracing::instrument(skip(self, user_id, rooms, globals))] pub fn mark_device_key_update( &self, user_id: &UserId, rooms: &super::rooms::Rooms, globals: &super::globals::Globals, - ) -> Result<()> { - let count = globals.next_count()?.to_be_bytes(); - for room_id in rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { - // Don't send key updates to unencrypted rooms - if rooms - .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? - .is_none() - { - continue; - } + ) -> Result<()>; - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - } - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - - Ok(()) - } - - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_device_keys( &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.as_bytes()); + ) -> Result>>; - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("DeviceKeys in db are invalid.") - })?)) - }) - } - - #[tracing::instrument(skip(self, user_id, allowed_signatures))] pub fn get_master_key bool>( &self, user_id: &UserId, allowed_signatures: F, - ) -> Result>> { - self.userid_masterkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - let mut cross_signing_key = serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; - clean_signatures(&mut cross_signing_key, user_id, allowed_signatures)?; + ) -> Result>>; - Ok(Some(Raw::from_json( - serde_json::value::to_raw_value(&cross_signing_key) - .expect("Value to RawValue serialization"), - ))) - }) - }) - } - - #[tracing::instrument(skip(self, user_id, allowed_signatures))] pub fn get_self_signing_key bool>( &self, user_id: &UserId, allowed_signatures: F, - ) -> Result>> { - self.userid_selfsigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - let mut cross_signing_key = serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; - clean_signatures(&mut cross_signing_key, user_id, allowed_signatures)?; + ) -> Result>>; - Ok(Some(Raw::from_json( - serde_json::value::to_raw_value(&cross_signing_key) - .expect("Value to RawValue serialization"), - ))) - }) - }) - } + pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>>; - #[tracing::instrument(skip(self, user_id))] - pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.userid_usersigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("CrossSigningKey in db is invalid.") - })?)) - }) - }) - } - - #[tracing::instrument(skip( - self, - sender, - target_user_id, - target_device_id, - event_type, - content, - globals - ))] pub fn add_to_device_event( &self, sender: &UserId, @@ -865,237 +177,52 @@ impl Users { event_type: &str, content: serde_json::Value, globals: &super::globals::Globals, - ) -> Result<()> { - let mut key = target_user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(target_device_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(&globals.next_count()?.to_be_bytes()); + ) -> Result<()>; - let mut json = serde_json::Map::new(); - json.insert("type".to_owned(), event_type.to_owned().into()); - json.insert("sender".to_owned(), sender.to_string().into()); - json.insert("content".to_owned(), content); - - let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); - - self.todeviceid_events.insert(&key, &value)?; - - Ok(()) - } - - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_to_device_events( &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result>> { - let mut events = Vec::new(); + ) -> Result>>; - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - - for (_, value) in self.todeviceid_events.scan_prefix(prefix) { - events.push( - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, - ); - } - - Ok(events) - } - - #[tracing::instrument(skip(self, user_id, device_id, until))] pub fn remove_to_device_events( &self, user_id: &UserId, device_id: &DeviceId, until: u64, - ) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); + ) -> Result<()>; - let mut last = prefix.clone(); - last.extend_from_slice(&until.to_be_bytes()); - - for (key, _) in self - .todeviceid_events - .iter_from(&last, true) // this includes last - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(key, _)| { - Ok::<_, Error>(( - key.clone(), - utils::u64_from_bytes(&key[key.len() - mem::size_of::()..key.len()]) - .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, - )) - }) - .filter_map(|r| r.ok()) - .take_while(|&(_, count)| count <= until) - { - self.todeviceid_events.remove(&key)?; - } - - Ok(()) - } - - #[tracing::instrument(skip(self, user_id, device_id, device))] pub fn update_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, device: &Device, - ) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Only existing devices should be able to call this. - assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(device).expect("Device::to_string always works"), - )?; - - Ok(()) - } + ) -> Result<()>; /// Get device metadata. - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, - ) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); + ) -> Result>; - self.userdeviceid_metadata - .get(&userdeviceid)? - .map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("Metadata in userdeviceid_metadata is invalid.") - })?)) - }) - } + pub fn get_devicelist_version(&self, user_id: &UserId) -> Result>; - #[tracing::instrument(skip(self, user_id))] - pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.userid_devicelistversion - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid devicelistversion in db.")) - .map(Some) - }) - } - - #[tracing::instrument(skip(self, user_id))] pub fn all_devices_metadata<'a>( &'a self, user_id: &UserId, - ) -> impl Iterator> + 'a { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - - self.userdeviceid_metadata - .scan_prefix(key) - .map(|(_, bytes)| { - serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) - }) - } - - /// Deactivate account - #[tracing::instrument(skip(self, user_id))] - pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { - // Remove all associated devices - for device_id in self.all_device_ids(user_id) { - self.remove_device(user_id, &device_id?)?; - } - - // Set the password to "" to indicate a deactivated account. Hashes will never result in an - // empty string, so the user will not be able to log in again. Systems like changing the - // password without logging in should check if the account is deactivated. - self.userid_password.insert(user_id.as_bytes(), &[])?; - - // TODO: Unhook 3PID - Ok(()) - } + ) -> impl Iterator> + 'a; /// Creates a new sync filter. Returns the filter id. - #[tracing::instrument(skip(self))] pub fn create_filter( &self, user_id: &UserId, filter: &IncomingFilterDefinition, - ) -> Result { - let filter_id = utils::random_string(4); + ) -> Result; - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(filter_id.as_bytes()); - - self.userfilterid_filter.insert( - &key, - &serde_json::to_vec(&filter).expect("filter is valid json"), - )?; - - Ok(filter_id) - } - - #[tracing::instrument(skip(self))] pub fn get_filter( &self, user_id: &UserId, filter_id: &str, - ) -> Result> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(filter_id.as_bytes()); - - let raw = self.userfilterid_filter.get(&key)?; - - if let Some(raw) = raw { - serde_json::from_slice(&raw) - .map_err(|_| Error::bad_database("Invalid filter event in db.")) - } else { - Ok(None) - } - } -} - -/// Ensure that a user only sees signatures from themselves and the target user -fn clean_signatures bool>( - cross_signing_key: &mut serde_json::Value, - user_id: &UserId, - allowed_signatures: F, -) -> Result<(), Error> { - if let Some(signatures) = cross_signing_key - .get_mut("signatures") - .and_then(|v| v.as_object_mut()) - { - // Don't allocate for the full size of the current signatures, but require - // at most one resize if nothing is dropped - let new_capacity = signatures.len() / 2; - for (user, signature) in - mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) - { - let id = <&UserId>::try_from(user.as_str()) - .map_err(|_| Error::bad_database("Invalid user ID in database."))?; - if id == user_id || allowed_signatures(id) { - signatures.insert(user, signature); - } - } - } - - Ok(()) + ) -> Result>; } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 7c15f1d8..93d6ea52 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,276 +1,107 @@ -use crate::{utils, Error, Result}; -use ruma::{ - api::client::{device::Device, error::ErrorKind, filter::IncomingFilterDefinition}, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{AnyToDeviceEvent, StateEventType}, - serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, MxcUri, RoomAliasId, - UInt, UserId, -}; -use std::{collections::BTreeMap, mem, sync::Arc}; -use tracing::warn; +mod data; +pub use data::Data; -use super::abstraction::Tree; +use crate::service::*; -pub struct Users { - pub(super) userid_password: Arc, - pub(super) userid_displayname: Arc, - pub(super) userid_avatarurl: Arc, - pub(super) userid_blurhash: Arc, - pub(super) userdeviceid_token: Arc, - pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists - pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 - pub(super) token_userdeviceid: Arc, - - pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId - pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count - pub(super) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count - pub(super) keyid_key: Arc, // KeyId = UserId + KeyId (depends on key type) - pub(super) userid_masterkeyid: Arc, - pub(super) userid_selfsigningkeyid: Arc, - pub(super) userid_usersigningkeyid: Arc, - - pub(super) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId - - pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count +pub struct Service { + db: D, } -impl Users { +impl Service<_> { /// Check if a user has an account on this homeserver. - #[tracing::instrument(skip(self, user_id))] pub fn exists(&self, user_id: &UserId) -> Result { - Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) + self.db.exists(user_id) } /// Check if account is deactivated - #[tracing::instrument(skip(self, user_id))] pub fn is_deactivated(&self, user_id: &UserId) -> Result { - Ok(self - .userid_password - .get(user_id.as_bytes())? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "User does not exist.", - ))? - .is_empty()) + self.db.is_deactivated(user_id) } /// Check if a user is an admin - #[tracing::instrument(skip(self, user_id, rooms, globals))] pub fn is_admin( &self, user_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result { - let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", globals.server_name())) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; - let admin_room_id = rooms.id_from_alias(&admin_room_alias_id)?.unwrap(); - - rooms.is_joined(user_id, &admin_room_id) + self.db.is_admin(user_id) } /// Create a new user account on this homeserver. - #[tracing::instrument(skip(self, user_id, password))] pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.set_password(user_id, password)?; - Ok(()) + self.db.set_password(user_id, password) } /// Returns the number of users registered on this server. - #[tracing::instrument(skip(self))] pub fn count(&self) -> Result { - Ok(self.userid_password.iter().count()) + self.db.count() } /// Find out which user an access token belongs to. - #[tracing::instrument(skip(self, token))] pub fn find_from_token(&self, token: &str) -> Result, String)>> { - self.token_userdeviceid - .get(token.as_bytes())? - .map_or(Ok(None), |bytes| { - let mut parts = bytes.split(|&b| b == 0xff); - let user_bytes = parts.next().ok_or_else(|| { - Error::bad_database("User ID in token_userdeviceid is invalid.") - })?; - let device_bytes = parts.next().ok_or_else(|| { - Error::bad_database("Device ID in token_userdeviceid is invalid.") - })?; - - Ok(Some(( - UserId::parse(utils::string_from_bytes(user_bytes).map_err(|_| { - Error::bad_database("User ID in token_userdeviceid is invalid unicode.") - })?) - .map_err(|_| { - Error::bad_database("User ID in token_userdeviceid is invalid.") - })?, - utils::string_from_bytes(device_bytes).map_err(|_| { - Error::bad_database("Device ID in token_userdeviceid is invalid.") - })?, - ))) - }) + self.db.find_from_token(token) } /// Returns an iterator over all users on this homeserver. - #[tracing::instrument(skip(self))] pub fn iter(&self) -> impl Iterator>> + '_ { - self.userid_password.iter().map(|(bytes, _)| { - UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in userid_password is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) - }) + self.db.iter() } /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is greater then zero. - #[tracing::instrument(skip(self))] pub fn list_local_users(&self) -> Result> { - let users: Vec = self - .userid_password - .iter() - .filter_map(|(username, pw)| self.get_username_with_valid_password(&username, &pw)) - .collect(); - Ok(users) + self.db.list_local_users() } /// Will only return with Some(username) if the password was not empty and the /// username could be successfully parsed. /// If utils::string_from_bytes(...) returns an error that username will be skipped /// and the error will be logged. - #[tracing::instrument(skip(self))] fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option { - // A valid password is not empty - if password.is_empty() { - None - } else { - match utils::string_from_bytes(username) { - Ok(u) => Some(u), - Err(e) => { - warn!( - "Failed to parse username while calling get_local_users(): {}", - e.to_string() - ); - None - } - } - } + self.db.get_username_with_valid_password(username, password) } /// Returns the password hash for the given user. - #[tracing::instrument(skip(self, user_id))] pub fn password_hash(&self, user_id: &UserId) -> Result> { - self.userid_password - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Password hash in db is not valid string.") - })?)) - }) + self.db.password_hash(user_id) } /// Hash and set the user's password to the Argon2 hash - #[tracing::instrument(skip(self, user_id, password))] pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - if let Some(password) = password { - if let Ok(hash) = utils::calculate_hash(password) { - self.userid_password - .insert(user_id.as_bytes(), hash.as_bytes())?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Password does not meet the requirements.", - )) - } - } else { - self.userid_password.insert(user_id.as_bytes(), b"")?; - Ok(()) - } + self.db.set_password(user_id, password) } /// Returns the displayname of a user on this homeserver. - #[tracing::instrument(skip(self, user_id))] pub fn displayname(&self, user_id: &UserId) -> Result> { - self.userid_displayname - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Displayname in db is invalid.") - })?)) - }) + self.db.displayname(user_id) } /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. - #[tracing::instrument(skip(self, user_id, displayname))] pub fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { - if let Some(displayname) = displayname { - self.userid_displayname - .insert(user_id.as_bytes(), displayname.as_bytes())?; - } else { - self.userid_displayname.remove(user_id.as_bytes())?; - } - - Ok(()) + self.db.set_displayname(user_id, displayname) } /// Get the avatar_url of a user. - #[tracing::instrument(skip(self, user_id))] pub fn avatar_url(&self, user_id: &UserId) -> Result>> { - self.userid_avatarurl - .get(user_id.as_bytes())? - .map(|bytes| { - let s = utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; - s.try_into() - .map_err(|_| Error::bad_database("Avatar URL in db is invalid.")) - }) - .transpose() + self.db.avatar_url(user_id) } /// Sets a new avatar_url or removes it if avatar_url is None. - #[tracing::instrument(skip(self, user_id, avatar_url))] pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option>) -> Result<()> { - if let Some(avatar_url) = avatar_url { - self.userid_avatarurl - .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; - } else { - self.userid_avatarurl.remove(user_id.as_bytes())?; - } - - Ok(()) + self.db.set_avatar_url(user_id, avatar_url) } /// Get the blurhash of a user. - #[tracing::instrument(skip(self, user_id))] pub fn blurhash(&self, user_id: &UserId) -> Result> { - self.userid_blurhash - .get(user_id.as_bytes())? - .map(|bytes| { - let s = utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; - - Ok(s) - }) - .transpose() + self.db.blurhash(user_id) } /// Sets a new avatar_url or removes it if avatar_url is None. - #[tracing::instrument(skip(self, user_id, blurhash))] pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { - if let Some(blurhash) = blurhash { - self.userid_blurhash - .insert(user_id.as_bytes(), blurhash.as_bytes())?; - } else { - self.userid_blurhash.remove(user_id.as_bytes())?; - } - - Ok(()) + self.db.set_blurhash(user_id, blurhash) } /// Adds a new device to a user. - #[tracing::instrument(skip(self, user_id, device_id, token, initial_device_display_name))] pub fn create_device( &self, user_id: &UserId, @@ -278,119 +109,27 @@ impl Users { token: &str, initial_device_display_name: Option, ) -> Result<()> { - // This method should never be called for nonexistent users. - assert!(self.exists(user_id)?); - - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(&Device { - device_id: device_id.into(), - display_name: initial_device_display_name, - last_seen_ip: None, // TODO - last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), - }) - .expect("Device::to_string never fails."), - )?; - - self.set_token(user_id, device_id, token)?; - - Ok(()) + self.db.create_device(user_id, device_id, token, initial_device_display_name) } /// Removes a device from a user. - #[tracing::instrument(skip(self, user_id, device_id))] pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Remove tokens - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.userdeviceid_token.remove(&userdeviceid)?; - self.token_userdeviceid.remove(&old_token)?; - } - - // Remove todevice events - let mut prefix = userdeviceid.clone(); - prefix.push(0xff); - - for (key, _) in self.todeviceid_events.scan_prefix(prefix) { - self.todeviceid_events.remove(&key)?; - } - - // TODO: Remove onetimekeys - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.remove(&userdeviceid)?; - - Ok(()) + self.db.remove_device(user_id, device_id) } /// Returns an iterator over all device ids of this user. - #[tracing::instrument(skip(self, user_id))] pub fn all_device_ids<'a>( &'a self, user_id: &UserId, ) -> impl Iterator>> + 'a { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - // All devices have metadata - self.userdeviceid_metadata - .scan_prefix(prefix) - .map(|(bytes, _)| { - Ok(utils::string_from_bytes( - bytes - .rsplit(|&b| b == 0xff) - .next() - .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, - ) - .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? - .into()) - }) + self.db.all_device_ids(user_id) } /// Replaces the access token of one device. - #[tracing::instrument(skip(self, user_id, device_id, token))] pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // All devices have metadata - assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); - - // Remove old token - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.token_userdeviceid.remove(&old_token)?; - // It will be removed from userdeviceid_token by the insert later - } - - // Assign token to user device combination - self.userdeviceid_token - .insert(&userdeviceid, token.as_bytes())?; - self.token_userdeviceid - .insert(token.as_bytes(), &userdeviceid)?; - - Ok(()) + self.db.set_token(user_id, device_id, token) } - #[tracing::instrument(skip( - self, - user_id, - device_id, - one_time_key_key, - one_time_key_value, - globals - ))] pub fn add_one_time_key( &self, user_id: &UserId, @@ -399,464 +138,103 @@ impl Users { one_time_key_value: &Raw, globals: &super::globals::Globals, ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.as_bytes()); - - // All devices have metadata - // Only existing devices should be able to call this. - assert!(self.userdeviceid_metadata.get(&key)?.is_some()); - - key.push(0xff); - // TODO: Use DeviceKeyId::to_string when it's available (and update everything, - // because there are no wrapping quotation marks anymore) - key.extend_from_slice( - serde_json::to_string(one_time_key_key) - .expect("DeviceKeyId::to_string always works") - .as_bytes(), - ); - - self.onetimekeyid_onetimekeys.insert( - &key, - &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), - )?; - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; - - Ok(()) + self.db.add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) } - #[tracing::instrument(skip(self, user_id))] pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.userid_lastonetimekeyupdate - .get(user_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") - }) - }) - .unwrap_or(Ok(0)) + self.db.last_one_time_keys_update(user_id) } - #[tracing::instrument(skip(self, user_id, device_id, key_algorithm, globals))] pub fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - globals: &super::globals::Globals, ) -> Result, Raw)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - prefix.push(b'"'); // Annoying quotation mark - prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); - prefix.push(b':'); - - self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; - - self.onetimekeyid_onetimekeys - .scan_prefix(prefix) - .next() - .map(|(key, value)| { - self.onetimekeyid_onetimekeys.remove(&key)?; - - Ok(( - serde_json::from_slice( - &*key - .rsplit(|&b| b == 0xff) - .next() - .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, - ) - .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, - serde_json::from_slice(&*value) - .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, - )) - }) - .transpose() + self.db.take_one_time_key(user_id, device_id, key_algorithm) } - #[tracing::instrument(skip(self, user_id, device_id))] pub fn count_one_time_keys( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - let mut counts = BTreeMap::new(); - - for algorithm in - self.onetimekeyid_onetimekeys - .scan_prefix(userdeviceid) - .map(|(bytes, _)| { - Ok::<_, Error>( - serde_json::from_slice::>( - &*bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { - Error::bad_database("OneTimeKey ID in db is invalid.") - })?, - ) - .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? - .algorithm(), - ) - }) - { - *counts.entry(algorithm?).or_default() += UInt::from(1_u32); - } - - Ok(counts) + self.db.count_one_time_keys(user_id, device_id) } - #[tracing::instrument(skip(self, user_id, device_id, device_keys, rooms, globals))] pub fn add_device_keys( &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.insert( - &userdeviceid, - &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), - )?; - - self.mark_device_key_update(user_id, rooms, globals)?; - - Ok(()) + self.db.add_device_keys(user_id, device_id, device_keys) } - #[tracing::instrument(skip( - self, - master_key, - self_signing_key, - user_signing_key, - rooms, - globals - ))] pub fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, user_signing_key: &Option>, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()> { - // TODO: Check signatures - - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - // Master key - let mut master_key_ids = master_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))? - .keys - .into_values(); - - let master_key_id = master_key_ids.next().ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained no key.", - ))?; - - if master_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained more than one key.", - )); - } - - let mut master_key_key = prefix.clone(); - master_key_key.extend_from_slice(master_key_id.as_bytes()); - - self.keyid_key - .insert(&master_key_key, master_key.json().get().as_bytes())?; - - self.userid_masterkeyid - .insert(user_id.as_bytes(), &master_key_key)?; - - // Self-signing key - if let Some(self_signing_key) = self_signing_key { - let mut self_signing_key_ids = self_signing_key - .deserialize() - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key") - })? - .keys - .into_values(); - - let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained no key.", - ))?; - - if self_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained more than one key.", - )); - } - - let mut self_signing_key_key = prefix.clone(); - self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); - - self.keyid_key.insert( - &self_signing_key_key, - self_signing_key.json().get().as_bytes(), - )?; - - self.userid_selfsigningkeyid - .insert(user_id.as_bytes(), &self_signing_key_key)?; - } - - // User-signing key - if let Some(user_signing_key) = user_signing_key { - let mut user_signing_key_ids = user_signing_key - .deserialize() - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key") - })? - .keys - .into_values(); - - let user_signing_key_id = user_signing_key_ids.next().ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained no key.", - ))?; - - if user_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained more than one key.", - )); - } - - let mut user_signing_key_key = prefix; - user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); - - self.keyid_key.insert( - &user_signing_key_key, - user_signing_key.json().get().as_bytes(), - )?; - - self.userid_usersigningkeyid - .insert(user_id.as_bytes(), &user_signing_key_key)?; - } - - self.mark_device_key_update(user_id, rooms, globals)?; - - Ok(()) + self.db.add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key) } - #[tracing::instrument(skip(self, target_id, key_id, signature, sender_id, rooms, globals))] pub fn sign_key( &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()> { - let mut key = target_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(key_id.as_bytes()); - - let mut cross_signing_key: serde_json::Value = - serde_json::from_slice(&self.keyid_key.get(&key)?.ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to sign nonexistent key.", - ))?) - .map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; - - let signatures = cross_signing_key - .get_mut("signatures") - .ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? - .as_object_mut() - .ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? - .entry(sender_id.to_owned()) - .or_insert_with(|| serde_json::Map::new().into()); - - signatures - .as_object_mut() - .ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))? - .insert(signature.0, signature.1.into()); - - self.keyid_key.insert( - &key, - &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), - )?; - - // TODO: Should we notify about this change? - self.mark_device_key_update(target_id, rooms, globals)?; - - Ok(()) + self.db.sign_key(target_id, key_id, signature, sender_id) } - #[tracing::instrument(skip(self, user_or_room_id, from, to))] pub fn keys_changed<'a>( &'a self, user_or_room_id: &str, from: u64, to: Option, ) -> impl Iterator>> + 'a { - let mut prefix = user_or_room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let mut start = prefix.clone(); - start.extend_from_slice(&(from + 1).to_be_bytes()); - - let to = to.unwrap_or(u64::MAX); - - self.keychangeid_userid - .iter_from(&start, false) - .take_while(move |(k, _)| { - k.starts_with(&prefix) - && if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) { - if let Ok(c) = utils::u64_from_bytes(current) { - c <= to - } else { - warn!("BadDatabase: Could not parse keychangeid_userid bytes"); - false - } - } else { - warn!("BadDatabase: Could not parse keychangeid_userid"); - false - } - }) - .map(|(_, bytes)| { - UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid.")) - }) + self.db.keys_changed(user_or_room_id, from, to) } - #[tracing::instrument(skip(self, user_id, rooms, globals))] pub fn mark_device_key_update( &self, user_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()> { - let count = globals.next_count()?.to_be_bytes(); - for room_id in rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { - // Don't send key updates to unencrypted rooms - if rooms - .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? - .is_none() - { - continue; - } - - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - } - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - - Ok(()) + self.db.mark_device_key_update(user_id) } - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_device_keys( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("DeviceKeys in db are invalid.") - })?)) - }) + self.db.get_device_keys(user_id, device_id) } - #[tracing::instrument(skip(self, user_id, allowed_signatures))] pub fn get_master_key bool>( &self, user_id: &UserId, allowed_signatures: F, ) -> Result>> { - self.userid_masterkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - let mut cross_signing_key = serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; - clean_signatures(&mut cross_signing_key, user_id, allowed_signatures)?; - - Ok(Some(Raw::from_json( - serde_json::value::to_raw_value(&cross_signing_key) - .expect("Value to RawValue serialization"), - ))) - }) - }) + self.db.get_master_key(user_id, allow_signatures) } - #[tracing::instrument(skip(self, user_id, allowed_signatures))] pub fn get_self_signing_key bool>( &self, user_id: &UserId, allowed_signatures: F, ) -> Result>> { - self.userid_selfsigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - let mut cross_signing_key = serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; - clean_signatures(&mut cross_signing_key, user_id, allowed_signatures)?; - - Ok(Some(Raw::from_json( - serde_json::value::to_raw_value(&cross_signing_key) - .expect("Value to RawValue serialization"), - ))) - }) - }) + self.db.get_self_signing_key(user_id, allowed_signatures) } - #[tracing::instrument(skip(self, user_id))] pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.userid_usersigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("CrossSigningKey in db is invalid.") - })?)) - }) - }) + self.db.get_user_signing_key(user_id) } - #[tracing::instrument(skip( - self, - sender, - target_user_id, - target_device_id, - event_type, - content, - globals - ))] pub fn add_to_device_event( &self, sender: &UserId, @@ -864,158 +242,57 @@ impl Users { target_device_id: &DeviceId, event_type: &str, content: serde_json::Value, - globals: &super::globals::Globals, ) -> Result<()> { - let mut key = target_user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(target_device_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(&globals.next_count()?.to_be_bytes()); - - let mut json = serde_json::Map::new(); - json.insert("type".to_owned(), event_type.to_owned().into()); - json.insert("sender".to_owned(), sender.to_string().into()); - json.insert("content".to_owned(), content); - - let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); - - self.todeviceid_events.insert(&key, &value)?; - - Ok(()) + self.db.add_to_device_event(sender, target_user_id, target_device_id, event_type, content) } - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_to_device_events( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result>> { - let mut events = Vec::new(); - - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - - for (_, value) in self.todeviceid_events.scan_prefix(prefix) { - events.push( - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, - ); - } - - Ok(events) + self.get_to_device_events(user_id, device_id) } - #[tracing::instrument(skip(self, user_id, device_id, until))] pub fn remove_to_device_events( &self, user_id: &UserId, device_id: &DeviceId, until: u64, ) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - - let mut last = prefix.clone(); - last.extend_from_slice(&until.to_be_bytes()); - - for (key, _) in self - .todeviceid_events - .iter_from(&last, true) // this includes last - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(key, _)| { - Ok::<_, Error>(( - key.clone(), - utils::u64_from_bytes(&key[key.len() - mem::size_of::()..key.len()]) - .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, - )) - }) - .filter_map(|r| r.ok()) - .take_while(|&(_, count)| count <= until) - { - self.todeviceid_events.remove(&key)?; - } - - Ok(()) + self.db.remove_to_device_events(user_id, device_id, until) } - #[tracing::instrument(skip(self, user_id, device_id, device))] pub fn update_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, device: &Device, ) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Only existing devices should be able to call this. - assert!(self.userdeviceid_metadata.get(&userdeviceid)?.is_some()); - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(device).expect("Device::to_string always works"), - )?; - - Ok(()) + self.db.update_device_metadata(user_id, device_id, device) } /// Get device metadata. - #[tracing::instrument(skip(self, user_id, device_id))] pub fn get_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userdeviceid_metadata - .get(&userdeviceid)? - .map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("Metadata in userdeviceid_metadata is invalid.") - })?)) - }) + self.get_device_metadata(user_id, device_id) } - #[tracing::instrument(skip(self, user_id))] pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.userid_devicelistversion - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid devicelistversion in db.")) - .map(Some) - }) + self.db.devicelist_version(user_id) } - #[tracing::instrument(skip(self, user_id))] pub fn all_devices_metadata<'a>( &'a self, user_id: &UserId, ) -> impl Iterator> + 'a { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - - self.userdeviceid_metadata - .scan_prefix(key) - .map(|(_, bytes)| { - serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) - }) + self.db.all_devices_metadata(user_id) } /// Deactivate account - #[tracing::instrument(skip(self, user_id))] pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { // Remove all associated devices for device_id in self.all_device_ids(user_id) { @@ -1032,44 +309,20 @@ impl Users { } /// Creates a new sync filter. Returns the filter id. - #[tracing::instrument(skip(self))] pub fn create_filter( &self, user_id: &UserId, filter: &IncomingFilterDefinition, ) -> Result { - let filter_id = utils::random_string(4); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(filter_id.as_bytes()); - - self.userfilterid_filter.insert( - &key, - &serde_json::to_vec(&filter).expect("filter is valid json"), - )?; - - Ok(filter_id) + self.db.create_filter(user_id, filter) } - #[tracing::instrument(skip(self))] pub fn get_filter( &self, user_id: &UserId, filter_id: &str, ) -> Result> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(filter_id.as_bytes()); - - let raw = self.userfilterid_filter.get(&key)?; - - if let Some(raw) = raw { - serde_json::from_slice(&raw) - .map_err(|_| Error::bad_database("Invalid filter event in db.")) - } else { - Ok(None) - } + self.db.get_filter(user_id, filter_id) } }