From 8773e5013d08a37851c4379db6bc66641602ece2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 11 Apr 2021 21:01:27 +0200 Subject: [PATCH] feat: incoming invites over federation --- Cargo.lock | 36 +++--- Cargo.toml | 8 +- src/client_server/account.rs | 10 +- src/client_server/membership.rs | 3 + src/client_server/sync.rs | 45 ++----- src/database.rs | 10 +- src/database/pusher.rs | 4 +- src/database/rooms.rs | 203 +++++++++++++++++++++++--------- src/main.rs | 1 + src/server_server.rs | 133 +++++++++++++++++---- 10 files changed, 307 insertions(+), 146 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cf881c2a..42042b62 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1625,7 +1625,7 @@ dependencies = [ [[package]] name = "ruma" version = "0.0.2" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "assign", "js_int", @@ -1645,7 +1645,7 @@ dependencies = [ [[package]] name = "ruma-api" version = "0.17.0-alpha.2" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "http", "percent-encoding", @@ -1660,7 +1660,7 @@ dependencies = [ [[package]] name = "ruma-api-macros" version = "0.17.0-alpha.2" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -1671,7 +1671,7 @@ dependencies = [ [[package]] name = "ruma-appservice-api" version = "0.2.0-alpha.2" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "ruma-api", "ruma-common", @@ -1685,7 +1685,7 @@ dependencies = [ [[package]] name = "ruma-client-api" version = "0.10.0-alpha.2" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "assign", "http", @@ -1704,7 +1704,7 @@ dependencies = [ [[package]] name = "ruma-common" version = "0.3.1" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "indexmap", "js_int", @@ -1720,7 +1720,7 @@ dependencies = [ [[package]] name = "ruma-events" version = "0.22.0-alpha.2" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "js_int", "ruma-common", @@ -1734,7 +1734,7 @@ dependencies = [ [[package]] name = "ruma-events-macros" version = "0.22.0-alpha.2" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -1745,7 +1745,7 @@ dependencies = [ [[package]] name = "ruma-federation-api" version = "0.1.0-alpha.1" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "js_int", "ruma-api", @@ -1760,7 +1760,7 @@ dependencies = [ [[package]] name = "ruma-identifiers" version = "0.18.1" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "paste", "rand", @@ -1774,7 +1774,7 @@ dependencies = [ [[package]] name = "ruma-identifiers-macros" version = "0.18.1" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "proc-macro2", "quote", @@ -1785,12 +1785,12 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" version = "0.2.2" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" [[package]] name = "ruma-identity-service-api" version = "0.0.1" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "ruma-api", "ruma-common", @@ -1803,7 +1803,7 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" version = "0.0.1" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "js_int", "ruma-api", @@ -1818,7 +1818,7 @@ dependencies = [ [[package]] name = "ruma-serde" version = "0.3.1" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "form_urlencoded", "itoa", @@ -1831,7 +1831,7 @@ dependencies = [ [[package]] name = "ruma-serde-macros" version = "0.3.1" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -1842,7 +1842,7 @@ dependencies = [ [[package]] name = "ruma-signatures" version = "0.6.0" -source = "git+https://github.com/ruma/ruma?rev=a310ccc318a4eb51062923d570d5a86c1468e8a1#a310ccc318a4eb51062923d570d5a86c1468e8a1" +source = "git+https://github.com/timokoesters/ruma?rev=b11de1e1f9d3c15267d09617131cf217f8277fa4#b11de1e1f9d3c15267d09617131cf217f8277fa4" dependencies = [ "base64 0.13.0", "ring", @@ -2120,7 +2120,7 @@ checksum = "3015a7d0a5fd5105c91c3710d42f9ccf0abfb287d62206484dcc67f9569a6483" [[package]] name = "state-res" version = "0.1.0" -source = "git+https://github.com/timokoesters/state-res?rev=1ec42ea2fc0b0728bf027a5899839ad94bb3091b#1ec42ea2fc0b0728bf027a5899839ad94bb3091b" +source = "git+https://github.com/timokoesters/state-res?rev=2e90b36babeb0d6b99ce8d4b513302a25dcdffc1#2e90b36babeb0d6b99ce8d4b513302a25dcdffc1" dependencies = [ "itertools 0.10.0", "log", diff --git a/Cargo.toml b/Cargo.toml index 3109dd8e..a28c08d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,12 +18,12 @@ rocket = { git = "https://github.com/SergioBenitez/Rocket.git", rev = "93e62c86e #rocket = { git = "https://github.com/timokoesters/Rocket.git", branch = "empty_parameters", default-features = false, features = ["tls"] } # Used for matrix spec type definitions and helpers -ruma = { git = "https://github.com/ruma/ruma", rev = "a310ccc318a4eb51062923d570d5a86c1468e8a1", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] } -#ruma = { git = "https://github.com/DevinR528/ruma", features = ["rand", "client-api", "federation-api", "push-gateway-api", "unstable-exhaustive-types", "unstable-pre-spec", "unstable-synapse-quirks"], branch = "verified-export" } -#ruma = { path = "../ruma/ruma", features = ["unstable-exhaustive-types", "rand", "client-api", "federation-api", "push-gateway-api", "unstable-pre-spec", "unstable-synapse-quirks"] } +#ruma = { git = "https://github.com/ruma/ruma", rev = "a310ccc318a4eb51062923d570d5a86c1468e8a1", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] } +ruma = { git = "https://github.com/timokoesters/ruma", rev = "b11de1e1f9d3c15267d09617131cf217f8277fa4", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] } +#ruma = { path = "../ruma/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "unstable-pre-spec", "unstable-exhaustive-types"] } # Used when doing state resolution -state-res = { git = "https://github.com/timokoesters/state-res", rev = "1ec42ea2fc0b0728bf027a5899839ad94bb3091b", features = ["unstable-pre-spec"] } +state-res = { git = "https://github.com/timokoesters/state-res", rev = "2e90b36babeb0d6b99ce8d4b513302a25dcdffc1", features = ["unstable-pre-spec"] } #state-res = { path = "../state-res", features = ["unstable-pre-spec"] } # Used for long polling and federation sender, should be the same as rocket::tokio diff --git a/src/client_server/account.rs b/src/client_server/account.rs index 4c5b60cb..2241d45e 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -617,11 +617,11 @@ pub async fn deactivate_route( } // Leave all joined rooms and reject all invitations - for room_id in db - .rooms - .rooms_joined(&sender_user) - .chain(db.rooms.rooms_invited(&sender_user)) - { + for room_id in db.rooms.rooms_joined(&sender_user).chain( + db.rooms + .rooms_invited(&sender_user) + .map(|t| t.map(|(r, _)| r)), + ) { let room_id = room_id?; let event = member::MemberEventContent { membership: member::MembershipState::Leave, diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 3f4f23ff..38762465 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -599,6 +599,8 @@ async fn join_room_by_id_helper( Error::BadServerResponse("Invalid user id in send_join response.") })?; + let invite_state = Vec::new(); // TODO add a few important events + // Update our membership info, we do this here incase a user is invited // and immediately leaves we need the DB to record the invite event for auth db.rooms.update_membership( @@ -616,6 +618,7 @@ async fn join_room_by_id_helper( Error::BadServerResponse("Invalid membership state content.") })?, &pdu.sender, + Some(invite_state), &db.account_data, &db.globals, )?; diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index bd7046dc..f1ad9a55 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -588,44 +588,23 @@ pub async fn sync_events_route( } let mut invited_rooms = BTreeMap::new(); - for room_id in db.rooms.rooms_invited(&sender_user) { - let room_id = room_id?; - let mut invited_since_last_sync = false; - for pdu in db.rooms.pdus_since(&sender_user, &room_id, since)? { - let (_, pdu) = pdu?; - if pdu.kind == EventType::RoomMember && pdu.state_key == Some(sender_user.to_string()) { - let content = serde_json::from_value::< - Raw, - >(pdu.content.clone()) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| Error::bad_database("Invalid PDU in database."))?; + for result in db.rooms.rooms_invited(&sender_user) { + let (room_id, invite_state_events) = result?; + let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?; - if content.membership == MembershipState::Invite { - invited_since_last_sync = true; - break; - } - } - } - - if !invited_since_last_sync { + // Invited before last sync + if Some(since) >= invite_count { continue; } - let invited_room = sync_events::InvitedRoom { - invite_state: sync_events::InviteState { - events: db - .rooms - .room_state_full(&room_id)? - .into_iter() - .map(|(_, pdu)| pdu.to_stripped_state_event()) - .collect(), + invited_rooms.insert( + room_id.clone(), + sync_events::InvitedRoom { + invite_state: sync_events::InviteState { + events: invite_state_events, + }, }, - }; - - if !invited_room.is_empty() { - invited_rooms.insert(room_id.clone(), invited_room); - } + ); } for user_id in left_encrypted_users { diff --git a/src/database.rs b/src/database.rs index a266c219..211c3f4b 100644 --- a/src/database.rs +++ b/src/database.rs @@ -161,8 +161,8 @@ impl Database { userroomid_joined: db.open_tree("userroomid_joined")?, roomuserid_joined: db.open_tree("roomuserid_joined")?, roomuseroncejoinedids: db.open_tree("roomuseroncejoinedids")?, - userroomid_invited: db.open_tree("userroomid_invited")?, - roomuserid_invited: db.open_tree("roomuserid_invited")?, + userroomid_invitestate: db.open_tree("userroomid_invitestate")?, + roomuserid_invitecount: db.open_tree("roomuserid_invitecount")?, userroomid_left: db.open_tree("userroomid_left")?, statekey_shortstatekey: db.open_tree("statekey_shortstatekey")?, @@ -236,7 +236,11 @@ impl Database { ); futures.push(self.rooms.userroomid_joined.watch_prefix(&userid_prefix)); - futures.push(self.rooms.userroomid_invited.watch_prefix(&userid_prefix)); + futures.push( + self.rooms + .userroomid_invitestate + .watch_prefix(&userid_prefix), + ); futures.push(self.rooms.userroomid_left.watch_prefix(&userid_prefix)); // Events for rooms we are in diff --git a/src/database/pusher.rs b/src/database/pusher.rs index 9a9452c0..f4c02d0a 100644 --- a/src/database/pusher.rs +++ b/src/database/pusher.rs @@ -216,11 +216,11 @@ pub async fn send_push_notice( notify = Some(n); } - if notify == Some(true) { + if notify == Some(true) { send_notice(unread, pusher, tweaks, pdu, db).await?; } // Else the event triggered no actions - + Ok(()) } diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 81697e33..ba987906 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -11,10 +11,10 @@ use ruma::{ events::{ ignored_user_list, room::{create::CreateEventContent, member, message}, - EventType, + AnyStrippedStateEvent, EventType, }, serde::{to_canonical_value, CanonicalJsonObject, CanonicalJsonValue, Raw}, - EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, + uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, }; use sled::IVec; use state_res::{Event, StateMap}; @@ -51,8 +51,8 @@ pub struct Rooms { pub(super) userroomid_joined: sled::Tree, pub(super) roomuserid_joined: sled::Tree, pub(super) roomuseroncejoinedids: sled::Tree, - pub(super) userroomid_invited: sled::Tree, - pub(super) roomuserid_invited: sled::Tree, + pub(super) userroomid_invitestate: sled::Tree, + pub(super) roomuserid_invitecount: sled::Tree, pub(super) userroomid_left: sled::Tree, /// Remember the current state hash of a room. @@ -145,12 +145,12 @@ impl Rooms { /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). #[tracing::instrument(skip(self))] - pub fn state_get( + pub fn state_get_id( &self, shortstatehash: u64, event_type: &EventType, state_key: &str, - ) -> Result> { + ) -> Result> { let mut key = event_type.as_ref().as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(&state_key.as_bytes()); @@ -161,7 +161,8 @@ impl Rooms { let mut stateid = shortstatehash.to_be_bytes().to_vec(); stateid.extend_from_slice(&shortstatekey); - self.stateid_shorteventid + Ok(self + .stateid_shorteventid .get(&stateid)? .map(|bytes| self.shorteventid_eventid.get(&bytes).ok().flatten()) .flatten() @@ -178,13 +179,24 @@ impl Rooms { ) }) .map(|r| r.ok()) - .flatten() - .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) + .flatten()) } else { Ok(None) } } + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn state_get( + &self, + shortstatehash: u64, + event_type: &EventType, + state_key: &str, + ) -> Result> { + self.state_get_id(shortstatehash, event_type, state_key)? + .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) + } + /// Returns the state hash for this pdu. #[tracing::instrument(skip(self))] pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { @@ -354,6 +366,21 @@ impl Rooms { } } + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). + #[tracing::instrument(skip(self))] + pub fn room_state_get_id( + &self, + room_id: &RoomId, + event_type: &EventType, + state_key: &str, + ) -> Result> { + if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + self.state_get_id(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } + /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). #[tracing::instrument(skip(self))] pub fn room_state_get( @@ -395,7 +422,7 @@ impl Rooms { } /// Returns the json of a pdu. - pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { + pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? .map_or_else::, _, _>( @@ -666,29 +693,64 @@ impl Rooms { // if the state_key fails let target_user_id = UserId::try_from(state_key.clone()) .expect("This state_key was previously validated"); + + let membership = serde_json::from_value::( + pdu.content + .get("membership") + .ok_or_else(|| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid member event content", + ) + })? + .clone(), + ) + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid membership state content.", + ) + })?; + + let invite_state = match membership { + member::MembershipState::Invite => { + let mut state = Vec::new(); + // Add recommended events + if let Some(e) = + self.room_state_get(&pdu.room_id, &EventType::RoomJoinRules, "")? + { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = self.room_state_get( + &pdu.room_id, + &EventType::RoomCanonicalAlias, + "", + )? { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = + self.room_state_get(&pdu.room_id, &EventType::RoomAvatar, "")? + { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = + self.room_state_get(&pdu.room_id, &EventType::RoomName, "")? + { + state.push(e.to_stripped_state_event()); + } + Some(state) + } + _ => None, + }; + // Update our membership info, we do this here incase a user is invited // and immediately leaves we need the DB to record the invite event for auth self.update_membership( &pdu.room_id, &target_user_id, - serde_json::from_value::( - pdu.content - .get("membership") - .ok_or_else(|| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid member event content", - ) - })? - .clone(), - ) - .map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid membership state content.", - ) - })?, + membership, &pdu.sender, + invite_state, &db.account_data, &db.globals, )?; @@ -1044,10 +1106,10 @@ impl Rooms { // Our depth is the maximum depth of prev_events + 1 let depth = prev_events .iter() - .filter_map(|event_id| Some(self.get_pdu_json(event_id).ok()??.get("depth")?.as_u64()?)) + .filter_map(|event_id| Some(self.get_pdu(event_id).ok()??.depth)) .max() - .unwrap_or(0_u64) - + 1; + .unwrap_or(uint!(0)) + + uint!(1); let mut unsigned = unsigned.unwrap_or_default(); if let Some(state_key) = &state_key { @@ -1071,9 +1133,7 @@ impl Rooms { content, state_key, prev_events, - depth: depth - .try_into() - .map_err(|_| Error::bad_database("Depth is invalid"))?, + depth, auth_events: auth_events .iter() .map(|(_, pdu)| pdu.event_id.clone()) @@ -1384,6 +1444,7 @@ impl Rooms { user_id: &UserId, membership: member::MembershipState, sender: &UserId, + invite_state: Option>>, account_data: &super::account_data::AccountData, globals: &super::globals::Globals, ) -> Result<()> { @@ -1487,8 +1548,8 @@ impl Rooms { self.roomserverids.insert(&roomserver_id, &[])?; self.userroomid_joined.insert(&userroom_id, &[])?; self.roomuserid_joined.insert(&roomuser_id, &[])?; - self.userroomid_invited.remove(&userroom_id)?; - self.roomuserid_invited.remove(&roomuser_id)?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; self.userroomid_left.remove(&userroom_id)?; } member::MembershipState::Invite => { @@ -1508,8 +1569,13 @@ impl Rooms { } self.roomserverids.insert(&roomserver_id, &[])?; - self.userroomid_invited.insert(&userroom_id, &[])?; - self.roomuserid_invited.insert(&roomuser_id, &[])?; + self.userroomid_invitestate.insert( + &userroom_id, + serde_json::to_vec(&invite_state.unwrap_or_default()) + .expect("state to bytes always works"), + )?; + self.roomuserid_invitecount + .insert(&roomuser_id, &globals.next_count()?.to_be_bytes())?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; self.userroomid_left.remove(&userroom_id)?; @@ -1526,8 +1592,8 @@ impl Rooms { self.userroomid_left.insert(&userroom_id, &[])?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_invited.remove(&userroom_id)?; - self.roomuserid_invited.remove(&roomuser_id)?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; } _ => {} } @@ -1797,7 +1863,7 @@ impl Rooms { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); - self.roomuserid_invited + self.roomuserid_invitecount .scan_prefix(prefix) .keys() .map(|key| { @@ -1816,6 +1882,22 @@ impl Rooms { }) } + /// Returns an iterator over all invited members of a room. + #[tracing::instrument(skip(self))] + pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_invitecount + .get(key)? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid invitecount in db.") + })?)) + }) + } + /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self))] pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator> { @@ -1840,27 +1922,32 @@ impl Rooms { /// Returns an iterator over all rooms a user was invited to. #[tracing::instrument(skip(self))] - pub fn rooms_invited(&self, user_id: &UserId) -> impl Iterator> { + pub fn rooms_invited( + &self, + user_id: &UserId, + ) -> impl Iterator>)>> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); - self.userroomid_invited - .scan_prefix(prefix) - .keys() - .map(|key| { - Ok(RoomId::try_from( - utils::string_from_bytes( - &key? - .rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, + self.userroomid_invitestate.scan_prefix(prefix).map(|r| { + let (key, state) = r?; + let room_id = RoomId::try_from( + utils::string_from_bytes( + &key.rsplit(|&b| b == 0xff) + .next() + .expect("rsplit always returns an element"), ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?) - }) + .map_err(|_| { + Error::bad_database("Room ID in userroomid_invited is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; + + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; + + Ok((room_id, state)) + }) } /// Returns an iterator over all rooms a user left. @@ -1906,7 +1993,7 @@ impl Rooms { userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); - Ok(self.userroomid_invited.get(userroom_id)?.is_some()) + Ok(self.userroomid_invitestate.get(userroom_id)?.is_some()) } pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { diff --git a/src/main.rs b/src/main.rs index 4ccc0251..6fd04ce6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -167,6 +167,7 @@ fn setup_rocket() -> (rocket::Rocket, Config) { server_server::get_event_route, server_server::get_missing_events_route, server_server::get_room_state_ids_route, + server_server::create_invite_route, server_server::get_profile_information_route, ], ) diff --git a/src/server_server.rs b/src/server_server.rs index 4a93a3d2..1fad54e4 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -10,20 +10,24 @@ use ruma::{ federation::{ directory::{get_public_rooms, get_public_rooms_filtered}, discovery::{ - get_remote_server_keys, get_server_keys, - get_server_version::v1 as get_server_version, ServerSigningKeys, VerifyKey, + get_remote_server_keys, get_server_keys, get_server_version, ServerSigningKeys, + VerifyKey, }, event::{get_event, get_missing_events, get_room_state_ids}, + membership::create_invite, query::get_profile_information, transactions::send_transaction_message, }, OutgoingRequest, }, directory::{IncomingFilter, IncomingRoomNetwork}, - events::{room::create::CreateEventContent, EventType}, + events::{ + room::{create::CreateEventContent, member::MembershipState}, + EventType, + }, serde::{to_canonical_value, Raw}, signatures::CanonicalJsonValue, - EventId, RoomId, ServerName, ServerSigningKeyId, UserId, + EventId, RoomId, RoomVersionId, ServerName, ServerSigningKeyId, UserId, }; use state_res::{Event, EventMap, StateMap}; use std::{ @@ -332,13 +336,13 @@ pub async fn request_well_known( #[tracing::instrument(skip(db))] pub fn get_server_version_route( db: State<'_, Database>, -) -> ConduitResult { +) -> ConduitResult { if !db.globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } - Ok(get_server_version::Response { - server: Some(get_server_version::Server { + Ok(get_server_version::v1::Response { + server: Some(get_server_version::v1::Server { name: Some("Conduit".to_owned()), version: Some(env!("CARGO_PKG_VERSION").to_owned()), }), @@ -1406,12 +1410,9 @@ pub fn get_event_route<'a>( origin: db.globals.server_name().to_owned(), origin_server_ts: SystemTime::now(), pdu: PduEvent::convert_to_outgoing_federation_event( - serde_json::from_value( - db.rooms - .get_pdu_json(&body.event_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?, - ) - .map_err(|_| Error::bad_database("Invalid pdu in database."))?, + db.rooms + .get_pdu_json(&body.event_id)? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?, ), } .into()) @@ -1438,9 +1439,10 @@ pub fn get_missing_events_route<'a>( if let Some(pdu) = db.rooms.get_pdu_json(&queued_events[i])? { if body.earliest_events.contains( &serde_json::from_value( - pdu.get("event_id") - .cloned() - .ok_or_else(|| Error::bad_database("Event in db has no event_id field."))?, + serde_json::to_value(pdu.get("event_id").cloned().ok_or_else(|| { + Error::bad_database("Event in db has no event_id field.") + })?) + .expect("canonical json is valid json value"), ) .map_err(|_| Error::bad_database("Invalid event_id field in pdu in db."))?, ) { @@ -1449,16 +1451,14 @@ pub fn get_missing_events_route<'a>( } queued_events.extend_from_slice( &serde_json::from_value::>( - pdu.get("prev_events").cloned().ok_or_else(|| { - Error::bad_database("Invalid prev_events field of pdu in db.") - })?, + serde_json::to_value(pdu.get("prev_events").cloned().ok_or_else(|| { + Error::bad_database("Event in db has no prev_events field.") + })?) + .expect("canonical json is valid json value"), ) .map_err(|_| Error::bad_database("Invalid prev_events content in pdu in db."))?, ); - events.push(PduEvent::convert_to_outgoing_federation_event( - serde_json::from_value(pdu) - .map_err(|_| Error::bad_database("Invalid pdu in database."))?, - )); + events.push(PduEvent::convert_to_outgoing_federation_event(pdu)); } i += 1; } @@ -1518,6 +1518,93 @@ pub fn get_room_state_ids_route<'a>( .into()) } +#[cfg_attr( + feature = "conduit_bin", + put("/_matrix/federation/v2/invite/<_>/<_>", data = "") +)] +#[tracing::instrument(skip(db, body))] +pub fn create_invite_route<'a>( + db: State<'a, Database>, + body: Ruma, +) -> ConduitResult { + if body.room_version < RoomVersionId::Version6 { + return Err(Error::BadRequest( + ErrorKind::IncompatibleRoomVersion { + room_version: body.room_version.clone(), + }, + "Server does not support this room version.", + )); + } + + let mut signed_event = utils::to_canonical_object(&body.event) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; + + ruma::signatures::hash_and_sign_event( + db.globals.server_name().as_str(), + db.globals.keypair(), + &mut signed_event, + &body.room_version, + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; + + let sender = serde_json::from_value( + serde_json::to_value( + signed_event + .get("sender") + .ok_or_else(|| { + Error::BadRequest(ErrorKind::InvalidParam, "Event had no sender field.") + })? + .clone(), + ) + .expect("CanonicalJsonValue to serde_json::Value always works"), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user id."))?; + let invited_user = serde_json::from_value( + serde_json::to_value( + signed_event + .get("state_key") + .ok_or_else(|| { + Error::BadRequest(ErrorKind::InvalidParam, "Event had no state_key field.") + })? + .clone(), + ) + .expect("CanonicalJsonValue to serde_json::Value always works"), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user id."))?; + + let mut invite_state = body.invite_room_state.clone(); + + let mut event = serde_json::from_str::>( + &body.event.json().to_string(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?; + + event.insert("event_id".to_owned(), "$dummy".into()); + invite_state.push( + serde_json::from_value::(event.into()) + .map_err(|e| { + warn!("Invalid invite event: {}", e); + Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event.") + })? + .to_stripped_state_event(), + ); + + db.rooms.update_membership( + &body.room_id, + &invited_user, + MembershipState::Invite, + &sender, + Some(invite_state), + &db.account_data, + &db.globals, + )?; + + Ok(create_invite::v2::Response { + event: PduEvent::convert_to_outgoing_federation_event(signed_event), + } + .into()) +} + #[cfg_attr( feature = "conduit_bin", get("/_matrix/federation/v1/query/profile", data = "")