diff --git a/Cargo.toml b/Cargo.toml index 1410f28..d9dc16d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,12 +11,22 @@ anyhow = "1.0" clap = "2.33.3" env_logger = "0.8.4" log = "0.4.13" +mime = "0.3.16" +mime_guess = "2.0.3" +rand = "0.8.4" serde_json = "1.0.61" +sled = "0.34.6" structopt = "0.3.21" toml = "0.5.8" +[dependencies.jm_client_core] +#path = "/home/lordmzte/dev/jensmemesclient/jensmemesclient/jm_client_core" +git = "https://tilera.xyz/git/lordmzte/jensmemesclient.git" +package = "jm_client_core" +rev = "0d3a77" + [dependencies.matrix-sdk] -version = "0.2.0" +git = "https://github.com/matrix-org/matrix-rust-sdk.git" features = ["encryption"] [dependencies.url] @@ -28,5 +38,5 @@ version = "1.0" features = ["derive"] [dependencies.tokio] -version = "0.2" +version = "1.7.0" features = ["macros"] diff --git a/exampleconfig.toml b/exampleconfig.toml index 9f1210e..060d912 100644 --- a/exampleconfig.toml +++ b/exampleconfig.toml @@ -1,7 +1,56 @@ +# URL of the homeserver to use homeserver_url = "https://matrix.org" user_id = "@rufftest:matrix.org" password = "xxx" -device_name = "RUFF" +# path to store databases +store_path = "store" + +# MEMES!! +memes = [ + # random stuff + { keyword = "alec", id = 650 }, + { keyword = "bastard", id = 375 }, + { keyword = "drogen", id = 191 }, + { keyword = "fresse", id = 375 }, + { keyword = "hendrik", randomcat = "hendrik" }, + { keyword = "hey", id = 243 }, + { keyword = "itbyhf", id = 314 }, + { keyword = "jonasled", id = 164 }, + { keyword = "kappa", id = 182 }, + { keyword = "lordmzte", id = 315 }, + { keyword = "realtox", id = 168 }, + { keyword = "sklave", id = 304 }, + { keyword = "tilera", id = 316 }, + + # üffen + { keyword = "uff", randomcat = "uff" }, + { keyword = "biguff", id = 771 }, + { keyword = "hmm", id = 892 }, + { keyword = "hmmm", id = 891 }, + { keyword = "longuff", id = 771 }, + { keyword = "uffal", id = 654 }, + { keyword = "uffat", id = 257 }, + { keyword = "uffch", id = 286 }, + { keyword = "uffde", id = 144 }, + { keyword = "uffgo", id = 568 }, + { keyword = "uffhf", id = 645 }, + { keyword = "uffhre", id = 312 }, + { keyword = "uffhs", id = 331 }, + { keyword = "uffj", id = 626 }, + { keyword = "uffjl", id = 773 }, + { keyword = "uffjs", id = 615 }, + { keyword = "uffkt", id = 627 }, + { keyword = "ufflie", id = 284 }, + { keyword = "uffmj", id = 831 }, + { keyword = "uffmz", id = 646 }, + { keyword = "uffns", id = 287 }, + { keyword = "uffpy", id = 477 }, + { keyword = "uffrs", id = 616 }, + { keyword = "uffrt", id = 986 }, + { keyword = "uffru", id = 999 }, + { keyword = "uffsb", id = 818 }, + { keyword = "uffsr", id = 585 }, + { keyword = "ufftl", id = 644 }, + { keyword = "uffwe", id = 779 }, +] -[memes] -uff = 144 diff --git a/src/config.rs b/src/config.rs index bbe9a90..7008077 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,9 +1,8 @@ -use matrix_sdk::identifiers::UserId; -use serde::{ - de::{self, Deserializer, Visitor}, - Deserialize, -}; -use std::{collections::BTreeMap, convert::TryFrom}; +use crate::meme::Meme; +use crate::meme::MemeIdent; +use serde::de::{self, Deserializer, MapAccess, Visitor}; +use serde::Deserialize; +use std::path::PathBuf; use url::Url; #[derive(Debug, Deserialize)] @@ -12,5 +11,113 @@ pub struct Config { pub user_id: String, pub password: String, pub device_name: Option, - pub memes: BTreeMap, + pub memes: Vec, + pub store_path: PathBuf, + #[serde(default = "default_clear_threshold")] + pub clear_cache_threshold: u32, +} + +fn default_clear_threshold() -> u32 { + 10 +} + +impl<'de> Deserialize<'de> for Meme { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + const FIELDS: &[&str] = &["keyword", "id", "randomcat"]; + enum Field { + Keyword, + Ident(IdentField), + } + + enum IdentField { + RandomCat, + Id, + } + + impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct Vis; + + impl<'de> Visitor<'de> for Vis { + type Value = Field; + fn expecting( + &self, + fmt: &mut std::fmt::Formatter<'_>, + ) -> Result<(), std::fmt::Error> { + fmt.write_str("a field for a meme") + } + + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + Ok(match v { + "keyword" => Field::Keyword, + "randomcat" => Field::Ident(IdentField::RandomCat), + "id" => Field::Ident(IdentField::Id), + _ => return Err(de::Error::unknown_field(v, FIELDS)), + }) + } + } + + deserializer.deserialize_identifier(Vis) + } + } + + struct Vis; + + impl<'de> Visitor<'de> for Vis { + type Value = Meme; + fn expecting(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + fmt.write_str("a meme") + } + + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + let mut keyword = None; + let mut ident = None; + while let Some(key) = map.next_key()? { + match key { + Field::Keyword => { + if keyword.is_some() { + return Err(de::Error::duplicate_field("keyword")); + } + + keyword = Some(map.next_value()?); + } + + Field::Ident(i) => { + if ident.is_some() { + return Err(de::Error::duplicate_field( + "ident, can only have one.", + )); + } + + match i { + IdentField::Id => { + ident = Some(MemeIdent::Id(map.next_value()?)); + } + IdentField::RandomCat => { + ident = Some(MemeIdent::RandomCat(map.next_value()?)); + } + } + } + } + } + let keyword = keyword.ok_or_else(|| de::Error::missing_field("keyword"))?; + let ident = ident.ok_or_else(|| de::Error::missing_field("ident"))?; + Ok(Meme { keyword, ident }) + } + } + + deserializer.deserialize_struct("Meme", FIELDS, Vis) + } } diff --git a/src/main.rs b/src/main.rs index 7bae30a..396d385 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,25 +1,34 @@ use anyhow::{anyhow, bail, Context}; +use jm_client_core::JMClient; use log::{error, info, warn}; -use matrix_sdk::SyncRoom; use matrix_sdk::{ - api::r0::{session::login, sync::sync_events}, + api::r0::session::login, async_trait, + deserialized_responses::SyncResponse, events::{ room::{ member::MemberEventContent, - message::{MessageEventContent, TextMessageEventContent}, + message::{MessageEventContent, MessageType, TextMessageEventContent}, }, - AnyMessageEventContent, AnyToDeviceEvent, StrippedStateEvent, SyncMessageEvent, + AnyToDeviceEvent, StrippedStateEvent, SyncMessageEvent, }, - EventEmitter, LoopCtrl, + room::Room, + verification::Verification, + EventHandler, LoopCtrl, }; +use rand::{rngs::StdRng, SeedableRng}; +use sled::Db; use std::{ collections::BTreeMap, path::PathBuf, - sync::{atomic::AtomicBool, Arc}, + sync::{ + atomic::{AtomicBool, AtomicU32}, + Arc, + }, time::Duration, }; use structopt::StructOpt; +use tokio::sync::Mutex; use tokio::sync::RwLock; use config::Config; @@ -27,6 +36,8 @@ use config::Config; use matrix_sdk::{self, api::r0::uiaa::AuthData, identifiers::UserId, Client, SyncSettings}; use serde_json::json; mod config; +mod meme; +mod responder; #[derive(Debug, StructOpt)] struct Opt { @@ -47,6 +58,7 @@ async fn main() -> anyhow::Result<()> { let config = std::fs::read(&opt.config).map_err(|e| anyhow!("Error reading config: {}", e))?; let config = toml::from_slice::(&config).map_err(|e| anyhow!("Error parsing config: {}", e))?; + let config = Arc::new(config); let client = Arc::new(RwLock::new(Client::new(config.homeserver_url.clone())?)); @@ -60,8 +72,14 @@ async fn main() -> anyhow::Result<()> { client .write() .await - .add_event_emitter(Box::new(Bot { + .set_event_handler(Box::new(Bot { client: Arc::clone(&client), + jm_client: RwLock::new(JMClient::new()), + memecache: sled::open(config.store_path.join("memecache")) + .map_err(|e| anyhow!("error opening memecache: {}", e))?, + config: Arc::clone(&config), + meme_count: AtomicU32::new(0), + rng: Mutex::new(StdRng::from_rng(rand::thread_rng())?), })) .await; @@ -82,8 +100,7 @@ async fn main() -> anyhow::Result<()> { if initial.load(std::sync::atomic::Ordering::SeqCst) { if let Err(e) = - on_initial_response(&response, client_ref, &user_id_ref, &config_ref.password) - .await + on_initial_response(client_ref, &user_id_ref, &config_ref.password).await { error!("Error processing initial response: {}", e); } @@ -98,15 +115,23 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -struct Bot { +pub struct Bot { client: Arc>, + jm_client: RwLock, + memecache: Db, + config: Arc, + /// used to keep track of how many memes have been sent. + /// this is reset once the threshold set in the config has been reached, and the JMClient cache + /// is cleared. + meme_count: AtomicU32, + rng: Mutex, } #[async_trait] -impl EventEmitter for Bot { +impl EventHandler for Bot { async fn on_stripped_state_member( &self, - room: SyncRoom, + room: Room, room_member: &StrippedStateEvent, _: Option, ) { @@ -114,16 +139,15 @@ impl EventEmitter for Bot { return; } - if let SyncRoom::Invited(room) = room { - let room = room.read().await; - println!("Autojoining room {}", room.room_id); + if let Room::Invited(room) = room { + println!("Autojoining room {}", room.room_id()); let mut delay = 2; while let Err(err) = self .client .read() .await - .join_room_by_id(&room.room_id) + .join_room_by_id(&room.room_id()) .await { // retry autojoin due to synapse sending invites, before the @@ -131,44 +155,64 @@ impl EventEmitter for Bot { // https://github.com/matrix-org/synapse/issues/4345 warn!( "Failed to join room {} ({:?}), retrying in {}s", - room.room_id, err, delay + room.room_id(), + err, + delay ); - tokio::time::delay_for(Duration::from_secs(delay)).await; + tokio::time::sleep(Duration::from_secs(delay)).await; delay *= 2; if delay > 3600 { - error!("Can't join room {} ({:?})", room.room_id, err); + error!("Can't join room {} ({:?})", room.room_id(), err); break; } } - info!("Successfully joined room {}", room.room_id); + + info!("Successfully joined room {}", room.room_id()); } } - async fn on_room_message(&self, _room: SyncRoom, msg: &SyncMessageEvent) { - dbg!(msg); + async fn on_room_message(&self, room: Room, msg: &SyncMessageEvent) { + if self + .client + .read() + .await + .user_id() + .await + .map(|u| u == msg.sender) + .unwrap_or(true) + { + return; + } + + if let SyncMessageEvent { + content: + MessageEventContent { + msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }), + .. + }, + .. + } = msg + { + if let Err(e) = responder::on_msg(msg_body, room, self).await { + error!("Responder error: {}", e); + } + } } } async fn on_initial_response( - _response: &sync_events::Response, client: &Client, user_id: &UserId, password: &str, ) -> anyhow::Result<()> { bootstrap_cross_signing(client, user_id, password).await?; - for (id, _room) in client.joined_rooms().read().await.iter() { - let content = AnyMessageEventContent::RoomMessage(MessageEventContent::Text( - TextMessageEventContent::plain("Hello world"), - )); - client.room_send(id, content, None).await?; - } Ok(()) } -async fn on_response(response: &sync_events::Response, client: &Client) -> anyhow::Result<()> { +async fn on_response(response: &SyncResponse, client: &Client) -> anyhow::Result<()> { for event in response .to_device .events @@ -178,7 +222,10 @@ async fn on_response(response: &sync_events::Response, client: &Client) -> anyho match event { AnyToDeviceEvent::KeyVerificationStart(e) => { info!("Starting verification"); - if let Some(sas) = &client.get_verification(&e.content.transaction_id).await { + if let Some(Verification::SasV1(sas)) = &client + .get_verification(&e.sender, &e.content.transaction_id) + .await + { if let Err(e) = sas.accept().await { error!("Error accepting key verification request: {}", e); } @@ -186,7 +233,10 @@ async fn on_response(response: &sync_events::Response, client: &Client) -> anyho } AnyToDeviceEvent::KeyVerificationKey(e) => { - if let Some(sas) = &client.get_verification(&e.content.transaction_id).await { + if let Some(Verification::SasV1(sas)) = &client + .get_verification(&e.sender, &e.content.transaction_id) + .await + { if let Err(e) = sas.confirm().await { error!("Error confirming key verification request: {}", e); } @@ -223,7 +273,7 @@ async fn bootstrap_cross_signing( password: &str, ) -> anyhow::Result<()> { info!("bootstrapping e2e"); - if let Err(e) = dbg!(client.bootstrap_cross_signing(None).await) { + if let Err(e) = client.bootstrap_cross_signing(None).await { warn!("couldnt bootstrap e2e without auth data"); if let Some(response) = e.uiaa_response() { let auth_data = auth_data(&user_id, &password, response.session.as_deref()); @@ -231,7 +281,7 @@ async fn bootstrap_cross_signing( .bootstrap_cross_signing(Some(auth_data)) .await .context("Couldn't bootstrap cross signing")?; - info!("bootstrapped e2e with auth data"); + info!("bootstrapped e2e with auth data"); } else { bail!("Error during cross-signing bootstrap {:#?}", e); } diff --git a/src/meme.rs b/src/meme.rs new file mode 100644 index 0000000..261724e --- /dev/null +++ b/src/meme.rs @@ -0,0 +1,39 @@ +use crate::Bot; +use rand::seq::IteratorRandom; + +#[derive(Debug)] +pub struct Meme { + pub keyword: String, + pub ident: MemeIdent, +} + +impl Meme { + /// checks if the meme should be triggered for the given message + pub fn matches(&self, msg: &str) -> bool { + let msg = msg.to_ascii_lowercase(); + msg.starts_with(&self.keyword) && + // msg must have one of allowed chars after keyword + msg.chars().nth(self.keyword.len()).map(|c|" ,.;:!?({-_".contains(c)).unwrap_or(true) + } + + pub async fn get_meme(&self, bot: &Bot) -> anyhow::Result> { + let memes = bot.jm_client.read().await.get_memes().await?; + match &self.ident { + MemeIdent::Id(i) => Ok(memes + .iter() + .find(|m| m.id.parse::().ok() == Some(*i)) + .cloned()), + MemeIdent::RandomCat(c) => Ok(memes + .iter() + .filter(|m| &m.category == c) + .choose(&mut *bot.rng.lock().await) + .cloned()), + } + } +} + +#[derive(Debug)] +pub enum MemeIdent { + RandomCat(String), + Id(u32), +} diff --git a/src/responder.rs b/src/responder.rs new file mode 100644 index 0000000..3f2badc --- /dev/null +++ b/src/responder.rs @@ -0,0 +1,107 @@ +use crate::Bot; +use anyhow::Context; +use log::{error, info, warn}; +use matrix_sdk::{ + api::r0::media::create_content, + events::{ + room::message::{ImageMessageEventContent, MessageEventContent, MessageType}, + AnyMessageEventContent, + }, + identifiers::MxcUri, + room::{Joined, Room}, +}; +use std::io::Cursor; +use std::sync::atomic::Ordering; + +pub async fn on_msg(msg: &str, room: Room, bot: &Bot) -> anyhow::Result<()> { + let room = match room { + Room::Joined(room) => room, + _ => { + warn!( + "Received message '{}' in room {:?} that's not joined", + msg, + room.name() + ); + return Ok(()); + } + }; + + for meme in &bot.config.memes { + if meme.matches(msg) { + bot.meme_count.fetch_add(1, Ordering::SeqCst); + if bot.meme_count.load(Ordering::SeqCst) >= bot.config.clear_cache_threshold { + bot.jm_client.write().await.clear_cache().await; + bot.meme_count.store(0, Ordering::SeqCst); + } + + let meme_name = &meme.keyword; + if let Some(meme) = meme.get_meme(bot).await? { + match meme.id.parse::() { + Err(e) => { + error!("Meme {:?} has invalid ID! tilera, you messed up with your stupid php again: {}", &meme, e); + } + Ok(id) => { + if let Some(ivec) = bot.memecache.get(id.to_le_bytes())? { + info!("Meme {} found in cache!", id); + let mxc = String::from_utf8(ivec.as_ref().to_vec()).context( + "Found invalid utf8 mxc url in memecache! Is the cache borked?", + )?; + + send_meme(&room, mxc.into(), meme_name.clone()).await?; + } else { + info!("Meme {} not found in cache, uploading...", id); + let resp = bot + .jm_client + .read() + .await + .http + .get(&meme.link) + .send() + .await + .context("error downloading meme")?; + let resp = resp.bytes().await?; + + if let Some(mime) = mime_guess::from_path(&meme.link).first() { + let create_content::Response { content_uri, .. } = bot + .client + .read() + .await + .upload(&mime, &mut Cursor::new(resp)) + .await?; + + bot.memecache.insert( + id.to_le_bytes(), + content_uri.to_string().into_bytes(), + )?; + + send_meme(&room, content_uri, meme_name.clone()).await?; + } else { + error!( + "Couldn't guess MIME type of meme '{}', skipping.", + &meme.link + ); + } + } + } + } + } else { + error!("Found meme with invalid id! {:?}", &meme); + } + } + } + + Ok(()) +} + +async fn send_meme(room: &Joined, mxc: MxcUri, meme_name: String) -> anyhow::Result<()> { + room.send( + AnyMessageEventContent::RoomMessage(MessageEventContent::new(MessageType::Image( + ImageMessageEventContent::plain(meme_name, mxc, None), + ))), + None, + ) + .await + .context("Failed to send meme")?; + + Ok(()) +}