2022-05-20 20:37:32 +02:00
|
|
|
use std::{
|
2023-03-30 17:18:59 +02:00
|
|
|
net::{IpAddr, SocketAddr},
|
|
|
|
sync::Arc,
|
2022-05-20 20:37:32 +02:00
|
|
|
time::Duration,
|
|
|
|
};
|
2019-12-06 22:19:07 +01:00
|
|
|
|
2023-05-25 20:57:31 +02:00
|
|
|
use chrono::{NaiveDateTime, Utc};
|
2022-05-20 20:37:32 +02:00
|
|
|
use rmpv::Value;
|
2023-03-30 17:18:59 +02:00
|
|
|
use rocket::{
|
|
|
|
futures::{SinkExt, StreamExt},
|
|
|
|
Route,
|
|
|
|
};
|
2022-05-20 20:37:32 +02:00
|
|
|
use tokio::{
|
|
|
|
net::{TcpListener, TcpStream},
|
|
|
|
sync::mpsc::Sender,
|
|
|
|
};
|
|
|
|
use tokio_tungstenite::{
|
|
|
|
accept_hdr_async,
|
|
|
|
tungstenite::{handshake, Message},
|
|
|
|
};
|
|
|
|
|
|
|
|
use crate::{
|
2023-03-30 17:18:59 +02:00
|
|
|
auth::ClientIp,
|
2023-06-11 13:28:18 +02:00
|
|
|
db::{
|
|
|
|
models::{Cipher, Folder, Send as DbSend, User},
|
|
|
|
DbConn,
|
|
|
|
},
|
2022-05-20 20:37:32 +02:00
|
|
|
Error, CONFIG,
|
|
|
|
};
|
2018-09-13 20:59:51 +02:00
|
|
|
|
2023-03-30 17:18:59 +02:00
|
|
|
use once_cell::sync::Lazy;
|
|
|
|
|
|
|
|
static WS_USERS: Lazy<Arc<WebSocketUsers>> = Lazy::new(|| {
|
|
|
|
Arc::new(WebSocketUsers {
|
|
|
|
map: Arc::new(dashmap::DashMap::new()),
|
|
|
|
})
|
|
|
|
});
|
|
|
|
|
2023-06-11 13:28:18 +02:00
|
|
|
use super::{push_cipher_update, push_folder_update, push_logout, push_send_update, push_user_update};
|
|
|
|
|
2018-08-24 19:02:34 +02:00
|
|
|
pub fn routes() -> Vec<Route> {
|
2023-03-30 17:18:59 +02:00
|
|
|
routes![websockets_hub]
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(FromForm, Debug)]
|
|
|
|
struct WsAccessToken {
|
|
|
|
access_token: Option<String>,
|
2018-09-11 17:09:33 +02:00
|
|
|
}
|
|
|
|
|
2023-03-30 17:18:59 +02:00
|
|
|
struct WSEntryMapGuard {
|
|
|
|
users: Arc<WebSocketUsers>,
|
|
|
|
user_uuid: String,
|
|
|
|
entry_uuid: uuid::Uuid,
|
|
|
|
addr: IpAddr,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl WSEntryMapGuard {
|
|
|
|
fn new(users: Arc<WebSocketUsers>, user_uuid: String, entry_uuid: uuid::Uuid, addr: IpAddr) -> Self {
|
|
|
|
Self {
|
|
|
|
users,
|
|
|
|
user_uuid,
|
|
|
|
entry_uuid,
|
|
|
|
addr,
|
|
|
|
}
|
2019-12-06 22:19:07 +01:00
|
|
|
}
|
2018-08-24 19:02:34 +02:00
|
|
|
}
|
|
|
|
|
2023-03-30 17:18:59 +02:00
|
|
|
impl Drop for WSEntryMapGuard {
|
|
|
|
fn drop(&mut self) {
|
|
|
|
info!("Closing WS connection from {}", self.addr);
|
|
|
|
if let Some(mut entry) = self.users.map.get_mut(&self.user_uuid) {
|
|
|
|
entry.retain(|(uuid, _)| uuid != &self.entry_uuid);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Rocket-hosted WebSocket notifications hub (`GET /hub?access_token=...`).
///
/// Authenticates the client from the `access_token` query parameter and
/// registers a channel for it in `WS_USERS`. It then returns a WebSocket
/// stream that:
/// - answers WebSocket pings with pongs,
/// - answers the MessagePack handshake (`INITIAL_MESSAGE`) with `INITIAL_RESPONSE`,
/// - ignores any other text frame and echoes every other frame back,
/// - forwards messages queued on this connection's channel (by `WebSocketUsers::send_update`),
/// - sends a ping (`create_ping`) every 15 seconds.
///
/// # Errors
/// Returns a 401 error when the token is missing ("Invalid claim") or is
/// rejected by `decode_login` ("Invalid token").
#[get("/hub?<data..>")]
fn websockets_hub<'r>(
    ws: rocket_ws::WebSocket,
    data: WsAccessToken,
    ip: ClientIp,
) -> Result<rocket_ws::Stream!['r], Error> {
    let addr = ip.ip;
    info!("Accepting Rocket WS connection from {addr}");

    let Some(token) = data.access_token else { err_code!("Invalid claim", 401) };
    let Ok(claims) = crate::auth::decode_login(&token) else { err_code!("Invalid token", 401) };

    let (mut rx, guard) = {
        let users = Arc::clone(&WS_USERS);

        // Add a channel to send messages to this client to the map
        let entry_uuid = uuid::Uuid::new_v4();
        let (tx, rx) = tokio::sync::mpsc::channel::<Message>(100);
        users.map.entry(claims.sub.clone()).or_default().push((entry_uuid, tx));

        // Once the guard goes out of scope, the connection will have been closed and the entry will be deleted from the map
        (rx, WSEntryMapGuard::new(users, claims.sub, entry_uuid, addr))
    };

    Ok({
        rocket_ws::Stream! { ws => {
            let mut ws = ws;
            // Moving the guard into the stream keeps the map entry alive exactly as long as the stream.
            let _guard = guard;
            let mut interval = tokio::time::interval(Duration::from_secs(15));
            loop {
                tokio::select! {
                    res = ws.next() => {
                        match res {
                            Some(Ok(message)) => {
                                match message {
                                    // Respond to any pings
                                    Message::Ping(ping) => yield Message::Pong(ping),
                                    Message::Pong(_) => {/* Ignored */},

                                    // We should receive an initial message with the protocol and version, and we will reply to it
                                    Message::Text(ref message) => {
                                        let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);

                                        if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
                                            yield Message::binary(INITIAL_RESPONSE);
                                            continue;
                                        }
                                        // Any other text frame falls through here and is ignored.
                                    }
                                    // Just echo anything else the client sends
                                    _ => yield message,
                                }
                            }
                            // Read error or end of stream: stop serving this connection.
                            _ => break,
                        }
                    }

                    // Forward updates queued for this connection.
                    res = rx.recv() => {
                        match res {
                            Some(res) => yield res,
                            None => break,
                        }
                    }

                    // Periodic application-level keep-alive.
                    _ = interval.tick() => yield Message::Ping(create_ping())
                }
            }
        }}
    })
}
|
|
|
|
|
2018-12-30 23:34:31 +01:00
|
|
|
//
|
|
|
|
// Websockets server
|
|
|
|
//
|
2018-08-30 17:43:46 +02:00
|
|
|
|
|
|
|
fn serialize(val: Value) -> Vec<u8> {
|
|
|
|
use rmpv::encode::write_value;
|
|
|
|
|
|
|
|
let mut buf = Vec::new();
|
|
|
|
write_value(&mut buf, &val).expect("Error encoding MsgPack");
|
|
|
|
|
|
|
|
// Add size bytes at the start
|
|
|
|
// Extracted from BinaryMessageFormat.js
|
2018-09-13 21:55:23 +02:00
|
|
|
let mut size: usize = buf.len();
|
2018-08-30 17:43:46 +02:00
|
|
|
let mut len_buf: Vec<u8> = Vec::new();
|
|
|
|
|
|
|
|
loop {
|
|
|
|
let mut size_part = size & 0x7f;
|
2018-09-13 21:55:23 +02:00
|
|
|
size >>= 7;
|
2018-08-30 17:43:46 +02:00
|
|
|
|
|
|
|
if size > 0 {
|
2018-09-13 21:55:23 +02:00
|
|
|
size_part |= 0x80;
|
2018-08-30 17:43:46 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
len_buf.push(size_part as u8);
|
|
|
|
|
2018-09-13 21:55:23 +02:00
|
|
|
if size == 0 {
|
2018-08-30 17:43:46 +02:00
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
len_buf.append(&mut buf);
|
|
|
|
len_buf
|
|
|
|
}
|
|
|
|
|
|
|
|
fn serialize_date(date: NaiveDateTime) -> Value {
|
|
|
|
let seconds: i64 = date.timestamp();
|
2019-02-20 17:54:18 +01:00
|
|
|
let nanos: i64 = date.timestamp_subsec_nanos().into();
|
2018-08-30 17:43:46 +02:00
|
|
|
let timestamp = nanos << 34 | seconds;
|
2019-01-25 18:23:51 +01:00
|
|
|
|
2019-01-16 22:14:17 +01:00
|
|
|
let bs = timestamp.to_be_bytes();
|
2018-08-30 17:43:46 +02:00
|
|
|
|
|
|
|
// -1 is Timestamp
|
|
|
|
// https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type
|
|
|
|
Value::Ext(-1, bs.to_vec())
|
|
|
|
}
|
|
|
|
|
|
|
|
fn convert_option<T: Into<Value>>(option: Option<T>) -> Value {
|
|
|
|
match option {
|
|
|
|
Some(a) => a.into(),
|
|
|
|
None => Value::Nil,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// ASCII "record separator" (0x1E) that terminates the text handshake messages.
const RECORD_SEPARATOR: u8 = 0x1e;
/// Handshake reply: an empty JSON object (`{}`) followed by the record separator.
const INITIAL_RESPONSE: [u8; 3] = [b'{', b'}', RECORD_SEPARATOR];
|
|
|
|
|
2022-05-20 20:37:32 +02:00
|
|
|
#[derive(Deserialize, Copy, Clone, Eq, PartialEq)]
|
|
|
|
struct InitialMessage<'a> {
|
|
|
|
protocol: &'a str,
|
2018-08-30 17:43:46 +02:00
|
|
|
version: i32,
|
|
|
|
}
|
|
|
|
|
2022-05-20 20:37:32 +02:00
|
|
|
/// The handshake the server replies to: MessagePack protocol, version 1.
static INITIAL_MESSAGE: InitialMessage<'static> = InitialMessage { protocol: "messagepack", version: 1 };
|
2018-08-30 17:43:46 +02:00
|
|
|
|
2022-05-20 20:37:32 +02:00
|
|
|
// We attach the UUID to the sender so we can differentiate them when we need to remove them from the Vec
|
|
|
|
type UserSenders = (uuid::Uuid, Sender<Message>);
|
2018-08-30 17:43:46 +02:00
|
|
|
#[derive(Clone)]
|
|
|
|
pub struct WebSocketUsers {
|
2022-05-20 20:37:32 +02:00
|
|
|
map: Arc<dashmap::DashMap<String, Vec<UserSenders>>>,
|
2018-08-30 17:43:46 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
impl WebSocketUsers {
|
2022-05-20 20:37:32 +02:00
|
|
|
async fn send_update(&self, user_uuid: &str, data: &[u8]) {
|
|
|
|
if let Some(user) = self.map.get(user_uuid).map(|v| v.clone()) {
|
|
|
|
for (_, sender) in user.iter() {
|
2023-03-30 17:18:59 +02:00
|
|
|
if let Err(e) = sender.send(Message::binary(data)).await {
|
|
|
|
error!("Error sending WS update {e}");
|
2022-05-20 20:37:32 +02:00
|
|
|
}
|
2018-08-30 17:43:46 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// NOTE: The last modified date needs to be updated before calling these methods
|
2022-05-20 20:37:32 +02:00
|
|
|
pub async fn send_user_update(&self, ut: UpdateType, user: &User) {
|
2018-08-30 17:43:46 +02:00
|
|
|
let data = create_update(
|
2021-04-06 22:54:42 +02:00
|
|
|
vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))],
|
2018-08-30 17:43:46 +02:00
|
|
|
ut,
|
2022-12-30 21:23:55 +01:00
|
|
|
None,
|
2018-08-30 17:43:46 +02:00
|
|
|
);
|
|
|
|
|
2022-05-20 20:37:32 +02:00
|
|
|
self.send_update(&user.uuid, &data).await;
|
2023-06-11 13:28:18 +02:00
|
|
|
|
|
|
|
if CONFIG.push_enabled() {
|
2023-06-16 23:34:16 +02:00
|
|
|
push_user_update(ut, user);
|
2023-06-11 13:28:18 +02:00
|
|
|
}
|
2018-08-30 17:43:46 +02:00
|
|
|
}
|
|
|
|
|
2023-06-16 23:34:16 +02:00
|
|
|
pub async fn send_logout(&self, user: &User, acting_device_uuid: Option<String>) {
|
2023-01-20 15:43:45 +01:00
|
|
|
let data = create_update(
|
|
|
|
vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))],
|
|
|
|
UpdateType::LogOut,
|
2023-06-11 13:28:18 +02:00
|
|
|
acting_device_uuid.clone(),
|
2023-01-20 15:43:45 +01:00
|
|
|
);
|
|
|
|
|
|
|
|
self.send_update(&user.uuid, &data).await;
|
2023-06-11 13:28:18 +02:00
|
|
|
|
|
|
|
if CONFIG.push_enabled() {
|
2023-06-16 23:34:16 +02:00
|
|
|
push_logout(user, acting_device_uuid);
|
2023-06-11 13:28:18 +02:00
|
|
|
}
|
2023-01-20 15:43:45 +01:00
|
|
|
}
|
|
|
|
|
2023-06-11 13:28:18 +02:00
|
|
|
pub async fn send_folder_update(
|
|
|
|
&self,
|
|
|
|
ut: UpdateType,
|
|
|
|
folder: &Folder,
|
|
|
|
acting_device_uuid: &String,
|
|
|
|
conn: &mut DbConn,
|
|
|
|
) {
|
2018-08-30 17:43:46 +02:00
|
|
|
let data = create_update(
|
|
|
|
vec![
|
|
|
|
("Id".into(), folder.uuid.clone().into()),
|
|
|
|
("UserId".into(), folder.user_uuid.clone().into()),
|
|
|
|
("RevisionDate".into(), serialize_date(folder.updated_at)),
|
2018-09-13 21:55:23 +02:00
|
|
|
],
|
2018-08-30 17:43:46 +02:00
|
|
|
ut,
|
2022-12-30 21:23:55 +01:00
|
|
|
Some(acting_device_uuid.into()),
|
2018-08-30 17:43:46 +02:00
|
|
|
);
|
|
|
|
|
2022-05-20 20:37:32 +02:00
|
|
|
self.send_update(&folder.user_uuid, &data).await;
|
2023-06-11 13:28:18 +02:00
|
|
|
|
|
|
|
if CONFIG.push_enabled() {
|
|
|
|
push_folder_update(ut, folder, acting_device_uuid, conn).await;
|
|
|
|
}
|
2018-08-30 17:43:46 +02:00
|
|
|
}
|
|
|
|
|
2022-12-30 21:23:55 +01:00
|
|
|
pub async fn send_cipher_update(
|
|
|
|
&self,
|
|
|
|
ut: UpdateType,
|
|
|
|
cipher: &Cipher,
|
|
|
|
user_uuids: &[String],
|
|
|
|
acting_device_uuid: &String,
|
2023-05-25 20:57:31 +02:00
|
|
|
collection_uuids: Option<Vec<String>>,
|
2023-06-11 13:28:18 +02:00
|
|
|
conn: &mut DbConn,
|
2022-12-30 21:23:55 +01:00
|
|
|
) {
|
2018-08-30 17:43:46 +02:00
|
|
|
let org_uuid = convert_option(cipher.organization_uuid.clone());
|
2023-05-25 20:57:31 +02:00
|
|
|
// Depending if there are collections provided or not, we need to have different values for the following variables.
|
|
|
|
// The user_uuid should be `null`, and the revision date should be set to now, else the clients won't sync the collection change.
|
|
|
|
let (user_uuid, collection_uuids, revision_date) = if let Some(collection_uuids) = collection_uuids {
|
|
|
|
(
|
|
|
|
Value::Nil,
|
|
|
|
Value::Array(collection_uuids.into_iter().map(|v| v.into()).collect::<Vec<rmpv::Value>>()),
|
|
|
|
serialize_date(Utc::now().naive_utc()),
|
|
|
|
)
|
|
|
|
} else {
|
|
|
|
(convert_option(cipher.user_uuid.clone()), Value::Nil, serialize_date(cipher.updated_at))
|
|
|
|
};
|
2018-08-30 17:43:46 +02:00
|
|
|
|
|
|
|
let data = create_update(
|
|
|
|
vec![
|
|
|
|
("Id".into(), cipher.uuid.clone().into()),
|
|
|
|
("UserId".into(), user_uuid),
|
|
|
|
("OrganizationId".into(), org_uuid),
|
2023-05-25 20:57:31 +02:00
|
|
|
("CollectionIds".into(), collection_uuids),
|
|
|
|
("RevisionDate".into(), revision_date),
|
2018-09-13 21:55:23 +02:00
|
|
|
],
|
2018-08-30 17:43:46 +02:00
|
|
|
ut,
|
2022-12-30 21:23:55 +01:00
|
|
|
Some(acting_device_uuid.into()),
|
2018-08-30 17:43:46 +02:00
|
|
|
);
|
|
|
|
|
2021-08-03 17:33:59 +02:00
|
|
|
for uuid in user_uuids {
|
2022-05-20 20:37:32 +02:00
|
|
|
self.send_update(uuid, &data).await;
|
2021-08-03 17:33:59 +02:00
|
|
|
}
|
2023-06-11 13:28:18 +02:00
|
|
|
|
|
|
|
if CONFIG.push_enabled() && user_uuids.len() == 1 {
|
|
|
|
push_cipher_update(ut, cipher, acting_device_uuid, conn).await;
|
|
|
|
}
|
2021-08-03 17:33:59 +02:00
|
|
|
}
|
|
|
|
|
2023-06-16 23:34:16 +02:00
|
|
|
pub async fn send_send_update(
|
|
|
|
&self,
|
|
|
|
ut: UpdateType,
|
|
|
|
send: &DbSend,
|
|
|
|
user_uuids: &[String],
|
|
|
|
acting_device_uuid: &String,
|
|
|
|
conn: &mut DbConn,
|
|
|
|
) {
|
2021-08-03 17:33:59 +02:00
|
|
|
let user_uuid = convert_option(send.user_uuid.clone());
|
|
|
|
|
|
|
|
let data = create_update(
|
|
|
|
vec![
|
|
|
|
("Id".into(), send.uuid.clone().into()),
|
|
|
|
("UserId".into(), user_uuid),
|
|
|
|
("RevisionDate".into(), serialize_date(send.revision_date)),
|
|
|
|
],
|
|
|
|
ut,
|
2022-12-30 21:23:55 +01:00
|
|
|
None,
|
2021-08-03 17:33:59 +02:00
|
|
|
);
|
|
|
|
|
2018-09-01 06:30:53 +02:00
|
|
|
for uuid in user_uuids {
|
2022-05-20 20:37:32 +02:00
|
|
|
self.send_update(uuid, &data).await;
|
2018-08-30 17:43:46 +02:00
|
|
|
}
|
2023-06-11 13:28:18 +02:00
|
|
|
if CONFIG.push_enabled() && user_uuids.len() == 1 {
|
2023-06-16 23:34:16 +02:00
|
|
|
push_send_update(ut, send, acting_device_uuid, conn).await;
|
2023-06-11 13:28:18 +02:00
|
|
|
}
|
2018-08-30 17:43:46 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/* Message Structure
|
|
|
|
[
|
|
|
|
1, // MessageType.Invocation
|
2020-08-31 19:05:07 +02:00
|
|
|
{}, // Headers (map)
|
2018-08-30 17:43:46 +02:00
|
|
|
null, // InvocationId
|
|
|
|
"ReceiveMessage", // Target
|
|
|
|
[ // Arguments
|
|
|
|
{
|
2022-12-30 21:23:55 +01:00
|
|
|
"ContextId": acting_device_uuid || Nil,
|
2018-08-30 17:43:46 +02:00
|
|
|
"Type": ut as i32,
|
|
|
|
"Payload": {}
|
|
|
|
}
|
|
|
|
]
|
|
|
|
]
|
|
|
|
*/
|
2022-12-30 21:23:55 +01:00
|
|
|
fn create_update(payload: Vec<(Value, Value)>, ut: UpdateType, acting_device_uuid: Option<String>) -> Vec<u8> {
|
2018-08-30 17:43:46 +02:00
|
|
|
use rmpv::Value as V;
|
|
|
|
|
|
|
|
let value = V::Array(vec![
|
|
|
|
1.into(),
|
2020-08-31 19:05:07 +02:00
|
|
|
V::Map(vec![]),
|
2018-08-30 17:43:46 +02:00
|
|
|
V::Nil,
|
|
|
|
"ReceiveMessage".into(),
|
|
|
|
V::Array(vec![V::Map(vec![
|
2022-12-30 21:23:55 +01:00
|
|
|
("ContextId".into(), acting_device_uuid.map(|v| v.into()).unwrap_or_else(|| V::Nil)),
|
2018-08-30 17:43:46 +02:00
|
|
|
("Type".into(), (ut as i32).into()),
|
|
|
|
("Payload".into(), payload.into()),
|
|
|
|
])]),
|
|
|
|
]);
|
|
|
|
|
|
|
|
serialize(value)
|
|
|
|
}
|
|
|
|
|
|
|
|
fn create_ping() -> Vec<u8> {
|
|
|
|
serialize(Value::Array(vec![6.into()]))
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Kind of update sent to clients; its discriminant becomes the `Type` field of the message.
///
/// The numeric values are part of the wire format and must not change.
// Not every variant is necessarily constructed in this crate.
#[allow(dead_code)]
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum UpdateType {
    // Vault items and folders
    SyncCipherUpdate = 0,
    SyncCipherCreate = 1,
    SyncLoginDelete = 2,
    SyncFolderDelete = 3,
    SyncCiphers = 4,

    SyncVault = 5,
    SyncOrgKeys = 6,
    SyncFolderCreate = 7,
    SyncFolderUpdate = 8,
    SyncCipherDelete = 9,
    SyncSettings = 10,

    LogOut = 11,

    // Sends
    SyncSendCreate = 12,
    SyncSendUpdate = 13,
    SyncSendDelete = 14,

    // Auth requests
    AuthRequest = 15,
    AuthRequestResponse = 16,

    None = 100,
}
|
|
|
|
|
2023-03-30 17:18:59 +02:00
|
|
|
/// Shorthand for Rocket's `State` holding the shared `WebSocketUsers` registry.
pub type Notify<'a> = &'a rocket::State<Arc<WebSocketUsers>>;
|
2018-08-30 17:43:46 +02:00
|
|
|
|
2023-03-30 17:18:59 +02:00
|
|
|
pub fn start_notification_server() -> Arc<WebSocketUsers> {
|
|
|
|
let users = Arc::clone(&WS_USERS);
|
2019-01-25 18:23:51 +01:00
|
|
|
if CONFIG.websocket_enabled() {
|
2023-03-30 17:18:59 +02:00
|
|
|
let users2 = Arc::<WebSocketUsers>::clone(&users);
|
2022-05-20 20:37:32 +02:00
|
|
|
tokio::spawn(async move {
|
|
|
|
let addr = (CONFIG.websocket_address(), CONFIG.websocket_port());
|
2022-05-21 19:12:38 +02:00
|
|
|
info!("Starting WebSockets server on {}:{}", addr.0, addr.1);
|
2022-05-20 20:37:32 +02:00
|
|
|
let listener = TcpListener::bind(addr).await.expect("Can't listen on websocket port");
|
|
|
|
|
|
|
|
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
|
|
|
CONFIG.set_ws_shutdown_handle(shutdown_tx);
|
|
|
|
|
|
|
|
loop {
|
|
|
|
tokio::select! {
|
|
|
|
Ok((stream, addr)) = listener.accept() => {
|
2023-03-30 17:18:59 +02:00
|
|
|
tokio::spawn(handle_connection(stream, Arc::<WebSocketUsers>::clone(&users2), addr));
|
2022-05-20 20:37:32 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
_ = &mut shutdown_rx => {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2021-11-07 18:53:39 +01:00
|
|
|
|
2022-05-20 20:37:32 +02:00
|
|
|
info!("Shutting down WebSockets server!")
|
2018-10-15 16:08:15 +02:00
|
|
|
});
|
|
|
|
}
|
2018-08-30 17:43:46 +02:00
|
|
|
|
|
|
|
users
|
|
|
|
}
|
2022-05-20 20:37:32 +02:00
|
|
|
|
2023-03-30 17:18:59 +02:00
|
|
|
/// Serves one connection on the standalone WebSockets server.
///
/// Performs the WebSocket handshake and rejects it with HTTP 401 unless
/// `get_request_token` finds a token that `decode_login` accepts. The
/// connection's channel is then registered in `users` and removed again when
/// `_guard` is dropped. The socket is driven until it closes:
/// - WebSocket pings are answered with pongs,
/// - the MessagePack handshake (`INITIAL_MESSAGE`) is answered with `INITIAL_RESPONSE`,
/// - other text frames are ignored and every other frame is echoed back,
/// - updates queued on `rx` are forwarded to the client,
/// - a ping (`create_ping`) is sent every 15 seconds.
///
/// # Errors
/// Returns an error if the handshake fails or any send on the socket fails.
async fn handle_connection(stream: TcpStream, users: Arc<WebSocketUsers>, addr: SocketAddr) -> Result<(), Error> {
    // Written by the handshake callback below once the token has been validated.
    let mut user_uuid: Option<String> = None;

    info!("Accepting WS connection from {addr}");

    // Accept connection, do initial handshake, validate auth token and get the user ID
    use handshake::server::{Request, Response};
    let mut stream = accept_hdr_async(stream, |req: &Request, res: Response| {
        if let Some(token) = get_request_token(req) {
            if let Ok(claims) = crate::auth::decode_login(&token) {
                user_uuid = Some(claims.sub);
                return Ok(res);
            }
        }
        Err(Response::builder().status(401).body(None).unwrap())
    })
    .await?;

    // The callback only returns Ok after setting `user_uuid`, and a rejected
    // handshake already returned through `?` above.
    let user_uuid = user_uuid.expect("User UUID should be set after the handshake");

    let (mut rx, guard) = {
        // Add a channel to send messages to this client to the map
        let entry_uuid = uuid::Uuid::new_v4();
        let (tx, rx) = tokio::sync::mpsc::channel::<Message>(100);
        users.map.entry(user_uuid.clone()).or_default().push((entry_uuid, tx));

        // Once the guard goes out of scope, the connection will have been closed and the entry will be deleted from the map
        (rx, WSEntryMapGuard::new(users, user_uuid, entry_uuid, addr.ip()))
    };

    // Bound to a named variable so it lives until the function returns (including via `?`).
    let _guard = guard;
    let mut interval = tokio::time::interval(Duration::from_secs(15));
    loop {
        tokio::select! {
            res = stream.next() => {
                match res {
                    Some(Ok(message)) => {
                        match message {
                            // Respond to any pings
                            Message::Ping(ping) => stream.send(Message::Pong(ping)).await?,
                            Message::Pong(_) => {/* Ignored */},

                            // We should receive an initial message with the protocol and version, and we will reply to it
                            Message::Text(ref message) => {
                                let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);

                                if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
                                    stream.send(Message::binary(INITIAL_RESPONSE)).await?;
                                    continue;
                                }
                                // Any other text frame falls through here and is ignored.
                            }
                            // Just echo anything else the client sends
                            _ => stream.send(message).await?,
                        }
                    }
                    // Read error or end of stream: stop serving this connection.
                    _ => break,
                }
            }

            // Forward updates queued for this connection.
            res = rx.recv() => {
                match res {
                    Some(res) => stream.send(res).await?,
                    None => break,
                }
            }

            // Periodic application-level keep-alive.
            _ = interval.tick() => stream.send(Message::Ping(create_ping())).await?
        }
    }

    Ok(())
}
|
|
|
|
|
|
|
|
fn get_request_token(req: &handshake::server::Request) -> Option<String> {
|
|
|
|
const ACCESS_TOKEN_KEY: &str = "access_token=";
|
|
|
|
|
|
|
|
if let Some(Ok(auth)) = req.headers().get("Authorization").map(|a| a.to_str()) {
|
|
|
|
if let Some(token_part) = auth.strip_prefix("Bearer ") {
|
|
|
|
return Some(token_part.to_owned());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(params) = req.uri().query() {
|
|
|
|
let params_iter = params.split('&').take(1);
|
|
|
|
for val in params_iter {
|
|
|
|
if let Some(stripped) = val.strip_prefix(ACCESS_TOKEN_KEY) {
|
|
|
|
return Some(stripped.to_owned());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
None
|
|
|
|
}
|