1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2024-12-26 08:44:17 +01:00

feat: save pdus

PDUs are saved in a pduid -> pdus map. roomid -> pduleaves keeps track
of the leaves of the event graph and eventid -> pduid maps event ids to
pdus.
This commit is contained in:
timokoesters 2020-04-03 17:27:08 +02:00
parent 22cca206ba
commit fa3226898c
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
8 changed files with 309 additions and 54 deletions

26
Cargo.lock generated
View file

@ -496,7 +496,9 @@ dependencies = [
"ruma-api",
"ruma-client-api",
"ruma-events",
"ruma-federation-api",
"ruma-identifiers",
"ruma-signatures",
"serde_json",
"sled",
]
@ -875,6 +877,19 @@ dependencies = [
"syn 1.0.17",
]
[[package]]
name = "ruma-federation-api"
version = "0.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2a73a23c4d9243be91e101e1942f4d9cd913ef5156d756bafdfe2409ee23d72"
dependencies = [
"js_int",
"ruma-events",
"ruma-identifiers",
"serde",
"serde_json",
]
[[package]]
name = "ruma-identifiers"
version = "0.14.1"
@ -886,6 +901,17 @@ dependencies = [
"url 2.1.1",
]
[[package]]
name = "ruma-signatures"
version = "0.5.0"
source = "git+https://github.com/ruma/ruma-signatures.git#a08fc01c0bce63f913e1b4b1a673169d59738b63"
dependencies = [
"base64 0.11.0",
"ring",
"serde_json",
"untrusted",
]
[[package]]
name = "rust-argon2"
version = "0.7.0"

View file

@ -19,4 +19,5 @@ ruma-api = "0.15.0"
ruma-events = "0.18.0"
js_int = "0.1.3"
serde_json = "1.0.50"
ruma-signatures = "0.5.0"
ruma-signatures = { git = "https://github.com/ruma/ruma-signatures.git" }
ruma-federation-api = "0.0.1"

View file

@ -1,3 +1,7 @@
[global]
address = "0.0.0.0"
port = 14004
#[global.tls]
#certs = "/etc/ssl/certs/ssl-cert-snakeoil.pem"
#key = "/etc/ssl/private/ssl-cert-snakeoil.key"
#certs = "/etc/letsencrypt/live/matrixtesting.koesters.xyz/fullchain.pem"
#key = "/etc/letsencrypt/live/matrixtesting.koesters.xyz/privkey.pem"

View file

@ -1,7 +1,9 @@
use crate::{utils, Database};
use log::debug;
use ruma_events::collections::all::Event;
use ruma_federation_api::RoomV3Pdu;
use ruma_identifiers::{EventId, RoomId, UserId};
use std::convert::TryInto;
use std::convert::{TryFrom, TryInto};
pub struct Data {
hostname: String,
@ -99,14 +101,152 @@ impl Data {
.unwrap();
}
/// Create a new room event.
pub fn event_add(&self, room_id: &RoomId, event_id: &EventId, event: &Event) {
let mut key = room_id.to_string().as_bytes().to_vec();
key.extend_from_slice(event_id.to_string().as_bytes());
pub fn pdu_get(&self, event_id: &EventId) -> Option<RoomV3Pdu> {
self.db
.roomid_eventid_event
.insert(&key, &*serde_json::to_string(event).unwrap())
.eventid_pduid
.get(event_id.to_string().as_bytes())
.unwrap()
.map(|pdu_id| {
serde_json::from_slice(
&self
.db
.pduid_pdus
.get(pdu_id)
.unwrap()
.expect("eventid_pduid in db is valid"),
)
.expect("pdu is valid")
})
}
// TODO: Make sure this isn't called twice in parallel
pub fn pdu_leaves_replace(&self, room_id: &RoomId, event_id: &EventId) -> Vec<EventId> {
let event_ids = self
.db
.roomid_pduleaves
.get_iter(room_id.to_string().as_bytes())
.values()
.map(|pdu_id| {
EventId::try_from(&*utils::string_from_bytes(&pdu_id.unwrap()))
.expect("pdu leaves are valid event ids")
})
.collect();
self.db
.roomid_pduleaves
.clear(room_id.to_string().as_bytes());
self.db.roomid_pduleaves.add(
&room_id.to_string().as_bytes(),
(*event_id.to_string()).into(),
);
event_ids
}
/// Add a persisted data unit from this homeserver
pub fn pdu_append(&self, event_id: &EventId, room_id: &RoomId, event: Event) {
// prev_events are the leaves of the current graph. This method removes all leaves from the
// room and replaces them with our event
let prev_events = self.pdu_leaves_replace(room_id, event_id);
// Our depth is the maximum depth of prev_events + 1
let depth = prev_events
.iter()
.map(|event_id| {
self.pdu_get(event_id)
.expect("pdu in prev_events is valid")
.depth
.into()
})
.max()
.unwrap_or(0_u64)
+ 1;
let mut pdu_value = serde_json::to_value(&event).expect("message event can be serialized");
let pdu = pdu_value.as_object_mut().unwrap();
pdu.insert(
"prev_events".to_owned(),
prev_events
.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.into(),
);
pdu.insert("origin".to_owned(), self.hostname().into());
pdu.insert("depth".to_owned(), depth.into());
pdu.insert("auth_events".to_owned(), vec!["$auth_eventid"].into()); // TODO
pdu.insert(
"hashes".to_owned(),
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA".into(),
); // TODO
pdu.insert("signatures".to_owned(), "signature".into()); // TODO
// The new value will need a new index. We store the last used index in 'n' + id
let mut count_key: Vec<u8> = vec![b'n'];
count_key.extend_from_slice(&room_id.to_string().as_bytes());
// Increment the last index and use that
let index = utils::u64_from_bytes(
&self
.db
.pduid_pdus
.update_and_fetch(&count_key, utils::increment)
.unwrap()
.unwrap(),
);
let mut pdu_id = vec![b'd'];
pdu_id.extend_from_slice(room_id.to_string().as_bytes());
pdu_id.push(b'#'); // Add delimiter so we don't find rooms starting with the same id
pdu_id.extend_from_slice(index.to_string().as_bytes());
self.db
.pduid_pdus
.insert(&pdu_id, dbg!(&*serde_json::to_string(&pdu).unwrap()))
.unwrap();
self.db
.eventid_pduid
.insert(event_id.to_string(), pdu_id.clone())
.unwrap();
}
/// Returns a vector of all PDUs.
pub fn pdus_all(&self) -> Vec<RoomV3Pdu> {
self.pdus_since(
self.db
.eventid_pduid
.iter()
.values()
.next()
.unwrap()
.map(|key| utils::string_from_bytes(&key))
.expect("there should be at least one pdu"),
)
}
/// Returns a vector of all events that happened after the event with id `since`.
pub fn pdus_since(&self, since: String) -> Vec<RoomV3Pdu> {
let mut pdus = Vec::new();
if let Some(room_id) = since.rsplitn(2, '#').nth(1) {
let mut current = since.clone();
while let Some((key, value)) = self.db.pduid_pdus.get_gt(current).unwrap() {
if key.starts_with(&room_id.to_string().as_bytes()) {
current = utils::string_from_bytes(&key);
} else {
break;
}
pdus.push(serde_json::from_slice(&value).expect("pdu is valid"));
}
} else {
debug!("event at `since` not found");
}
pdus
}
pub fn debug(&self) {

View file

@ -15,11 +15,17 @@ impl MultiValue {
// Data keys start with d
let mut key = vec![b'd'];
key.extend_from_slice(id.as_ref());
key.push(0xff); // Add delimiter so we don't find usernames starting with the same id
key.push(0xff); // Add delimiter so we don't find keys starting with the same id
self.0.scan_prefix(key)
}
pub fn clear(&self, id: &[u8]) {
for key in self.get_iter(id).keys() {
self.0.remove(key.unwrap()).unwrap();
}
}
/// Add another value to the id.
pub fn add(&self, id: &[u8], value: IVec) {
// The new value will need a new index. We store the last used index in 'n' + id
@ -48,7 +54,9 @@ pub struct Database {
pub userid_deviceids: MultiValue,
pub deviceid_token: sled::Tree,
pub token_userid: sled::Tree,
pub roomid_eventid_event: sled::Tree,
pub pduid_pdus: sled::Tree,
pub roomid_pduleaves: MultiValue,
pub eventid_pduid: sled::Tree,
_db: sled::Db,
}
@ -67,7 +75,9 @@ impl Database {
userid_deviceids: MultiValue(db.open_tree("userid_deviceids").unwrap()),
deviceid_token: db.open_tree("deviceid_token").unwrap(),
token_userid: db.open_tree("token_userid").unwrap(),
roomid_eventid_event: db.open_tree("roomid_eventid_event").unwrap(),
pduid_pdus: db.open_tree("pduid_pdus").unwrap(),
roomid_pduleaves: MultiValue(db.open_tree("roomid_pduleaves").unwrap()),
eventid_pduid: db.open_tree("eventid_pduid").unwrap(),
_db: db,
}
}
@ -81,7 +91,7 @@ impl Database {
String::from_utf8_lossy(&v),
);
}
println!("# UserId -> DeviceIds:");
println!("\n# UserId -> DeviceIds:");
for (k, v) in self.userid_deviceids.iter_all().map(|r| r.unwrap()) {
println!(
"{} -> {}",
@ -89,7 +99,7 @@ impl Database {
String::from_utf8_lossy(&v),
);
}
println!("# DeviceId -> Token:");
println!("\n# DeviceId -> Token:");
for (k, v) in self.deviceid_token.iter().map(|r| r.unwrap()) {
println!(
"{} -> {}",
@ -97,7 +107,7 @@ impl Database {
String::from_utf8_lossy(&v),
);
}
println!("# Token -> UserId:");
println!("\n# Token -> UserId:");
for (k, v) in self.token_userid.iter().map(|r| r.unwrap()) {
println!(
"{} -> {}",
@ -105,8 +115,24 @@ impl Database {
String::from_utf8_lossy(&v),
);
}
println!("# RoomId + EventId -> Event:");
for (k, v) in self.roomid_eventid_event.iter().map(|r| r.unwrap()) {
println!("\n# RoomId -> PDU leaves:");
for (k, v) in self.roomid_pduleaves.iter_all().map(|r| r.unwrap()) {
println!(
"{} -> {}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# PDU Id -> PDUs:");
for (k, v) in self.pduid_pdus.iter().map(|r| r.unwrap()) {
println!(
"{} -> {}",
String::from_utf8_lossy(&k),
String::from_utf8_lossy(&v),
);
}
println!("\n# EventId -> PDU Id:");
for (k, v) in self.eventid_pduid.iter().map(|r| r.unwrap()) {
println!(
"{} -> {}",
String::from_utf8_lossy(&k),

View file

@ -8,12 +8,12 @@ pub use data::Data;
pub use database::Database;
use log::debug;
use rocket::{get, post, put, routes, State};
use rocket::{get, options, post, put, routes, State};
use ruma_client_api::{
error::{Error, ErrorKind},
r0::{
account::register, alias::get_alias, membership::join_room_by_id,
message::create_message_event, session::login,
message::create_message_event, session::login, sync::sync_events,
},
unversioned::get_supported_versions,
};
@ -24,20 +24,13 @@ use serde_json::map::Map;
use std::{
collections::HashMap,
convert::{TryFrom, TryInto},
path::PathBuf,
};
#[get("/_matrix/client/versions")]
fn get_supported_versions_route() -> MatrixResult<get_supported_versions::Response> {
MatrixResult(Ok(get_supported_versions::Response {
versions: vec![
"r0.0.1".to_owned(),
"r0.1.0".to_owned(),
"r0.2.0".to_owned(),
"r0.3.0".to_owned(),
"r0.4.0".to_owned(),
"r0.5.0".to_owned(),
"r0.6.0".to_owned(),
],
versions: vec!["r0.6.0".to_owned()],
unstable_features: HashMap::new(),
}))
}
@ -219,9 +212,9 @@ fn create_message_event_route(
body: Ruma<create_message_event::Request>,
) -> MatrixResult<create_message_event::Response> {
// Construct event
let event = Event::RoomMessage(MessageEvent {
let mut event = Event::RoomMessage(MessageEvent {
content: body.data.clone().into_result().unwrap(),
event_id: event_id.clone(),
event_id: EventId::try_from("$thiswillbefilledinlater").unwrap(),
origin_server_ts: utils::millis_since_unix_epoch(),
room_id: Some(body.room_id.clone()),
sender: body.user_id.clone().expect("user is authenticated"),
@ -229,18 +222,78 @@ fn create_message_event_route(
});
// Generate event id
dbg!(ruma_signatures::reference_hash(event));
let event_id = EventId::try_from(&*format!(
"${}",
ruma_signatures::reference_hash(&serde_json::to_value(&event).unwrap())
.expect("ruma can calculate reference hashes")
))
.expect("ruma's reference hashes are correct");
let event_id = EventId::try_from("$TODOrandomeventid:localhost").unwrap();
data.event_add(&body.room_id, &event_id, &event);
// Insert event id
if let Event::RoomMessage(message) = &mut event {
message.event_id = event_id.clone();
}
// Add PDU to the graph
data.pdu_append(&event_id, &body.room_id, event);
MatrixResult(Ok(create_message_event::Response { event_id }))
}
#[get("/_matrix/client/r0/sync")]
fn sync_route(data: State<Data>) -> MatrixResult<sync_events::Response> {
let pdus = data.pdus_all();
let mut joined_rooms = HashMap::new();
joined_rooms.insert(
"!roomid:localhost".try_into().unwrap(),
sync_events::JoinedRoom {
account_data: sync_events::AccountData { events: Vec::new() },
summary: sync_events::RoomSummary {
heroes: Vec::new(),
joined_member_count: None,
invited_member_count: None,
},
unread_notifications: sync_events::UnreadNotificationsCount {
highlight_count: None,
notification_count: None,
},
timeline: sync_events::Timeline {
limited: None,
prev_batch: None,
events: todo!(),
},
state: sync_events::State { events: Vec::new() },
ephemeral: sync_events::Ephemeral { events: Vec::new() },
},
);
MatrixResult(Ok(sync_events::Response {
next_batch: String::new(),
rooms: sync_events::Rooms {
leave: Default::default(),
join: joined_rooms,
invite: Default::default(),
},
presence: sync_events::Presence { events: Vec::new() },
device_lists: Default::default(),
device_one_time_keys_count: Default::default(),
to_device: sync_events::ToDevice { events: Vec::new() },
}))
}
#[options("/<_segments..>")]
fn options_route(_segments: PathBuf) -> MatrixResult<create_message_event::Response> {
MatrixResult(Err(Error {
kind: ErrorKind::NotFound,
message: "Room not found.".to_owned(),
status_code: http::StatusCode::NOT_FOUND,
}))
}
fn main() {
// Log info by default
if let Err(_) = std::env::var("RUST_LOG") {
std::env::set_var("RUST_LOG", "info");
std::env::set_var("RUST_LOG", "matrixserver=debug,info");
}
pretty_env_logger::init();
@ -257,6 +310,8 @@ fn main() {
get_alias_route,
join_room_by_id_route,
create_message_event_route,
sync_route,
options_route,
],
)
.manage(data)

View file

@ -1,28 +1,26 @@
use {
rocket::data::{FromDataSimple, Outcome},
rocket::http::Status,
rocket::response::Responder,
rocket::Outcome::*,
rocket::Request,
rocket::State,
ruma_api::{
error::{FromHttpRequestError, FromHttpResponseError},
Endpoint, Outgoing,
},
ruma_client_api::error::Error,
ruma_identifiers::UserId,
std::ops::Deref,
std::{
convert::{TryFrom, TryInto},
io::{Cursor, Read},
},
use rocket::{
data::{FromDataSimple, Outcome},
http::Status,
response::Responder,
Outcome::*,
Request, State,
};
use ruma_api::{
error::{FromHttpRequestError, FromHttpResponseError},
Endpoint, Outgoing,
};
use ruma_client_api::error::Error;
use ruma_identifiers::UserId;
use std::{
convert::{TryFrom, TryInto},
io::{Cursor, Read},
ops::Deref,
};
const MESSAGE_LIMIT: u64 = 65535;
/// This struct converts rocket requests into ruma structs by converting them into http requests
/// first.
#[derive(Debug)]
pub struct Ruma<T: Outgoing> {
body: T::Incoming,
pub user_id: Option<UserId>,

View file

@ -24,6 +24,11 @@ pub fn increment(old: Option<&[u8]>) -> Option<Vec<u8>> {
Some(number.to_be_bytes().to_vec())
}
pub fn u64_from_bytes(bytes: &[u8]) -> u64 {
let array: [u8; 8] = bytes.try_into().expect("bytes are valid u64");
u64::from_be_bytes(array)
}
pub fn string_from_bytes(bytes: &[u8]) -> String {
String::from_utf8(bytes.to_vec()).expect("bytes are valid utf8")
}