diff --git a/exampleconfig.toml b/exampleconfig.toml index 5c4ef07..11c2ad3 100644 --- a/exampleconfig.toml +++ b/exampleconfig.toml @@ -23,6 +23,8 @@ memes = [ { keyword = "realtox", id = 168 }, { keyword = "sklave", id = 304 }, { keyword = "tilera", id = 316 }, + { keyword = "wtf", randomcat = "random", match = "contains" }, + { keyword = "wtuff", randomcat = "random", match = "contains" }, # üffen { keyword = "uff", randomcat = "uff" }, diff --git a/src/config.rs b/src/config.rs index 5e7308c..6e9f84d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,4 @@ -use crate::meme::{Meme, MemeIdent}; +use crate::meme::{Matcher, Meme, MemeIdent}; use serde::{ de::{self, Deserializer, MapAccess, Visitor}, Deserialize, @@ -31,6 +31,7 @@ impl<'de> Deserialize<'de> for Meme { enum Field { Keyword, Ident(IdentField), + Matcher, } enum IdentField { @@ -62,6 +63,7 @@ impl<'de> Deserialize<'de> for Meme { "keyword" => Field::Keyword, "randomcat" => Field::Ident(IdentField::RandomCat), "id" => Field::Ident(IdentField::Id), + "match" => Field::Matcher, _ => return Err(de::Error::unknown_field(v, FIELDS)), }) } @@ -85,6 +87,7 @@ impl<'de> Deserialize<'de> for Meme { { let mut keyword = None; let mut ident = None; + let mut matcher = None; while let Some(key) = map.next_key()? { match key { Field::Keyword => { @@ -93,7 +96,7 @@ impl<'de> Deserialize<'de> for Meme { } keyword = Some(map.next_value()?); - }, + } Field::Ident(i) => { if ident.is_some() { @@ -105,17 +108,32 @@ impl<'de> Deserialize<'de> for Meme { match i { IdentField::Id => { ident = Some(MemeIdent::Id(map.next_value()?)); - }, + } IdentField::RandomCat => { ident = Some(MemeIdent::RandomCat(map.next_value()?)); - }, + } } - }, + } + + Field::Matcher => { + if matcher.is_some() { + return Err(de::Error::duplicate_field("match")); + } + + matcher = Some(map.next_value()?); + } } } + let keyword = keyword.ok_or_else(|| de::Error::missing_field("keyword"))?; let ident = ident.ok_or_else(|| de::Error::missing_field("ident"))?; - Ok(Meme { keyword, ident }) + let matcher = matcher.unwrap_or(Matcher::Begins); + + Ok(Meme { + keyword, + ident, + matcher, + }) } } diff --git a/src/main.rs b/src/main.rs index 9759411..44799dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,14 +10,11 @@ use matrix_sdk::{ member::MemberEventContent, message::{MessageEventContent, MessageType, TextMessageEventContent}, }, - AnyToDeviceEvent, - StrippedStateEvent, - SyncMessageEvent, + AnyToDeviceEvent, StrippedStateEvent, SyncMessageEvent, }, room::Room, verification::Verification, - EventHandler, - LoopCtrl, + EventHandler, LoopCtrl, }; use rand::{rngs::StdRng, SeedableRng}; use sled::Db; @@ -233,7 +230,7 @@ async fn on_response(response: &SyncResponse, client: &Client) -> anyhow::Result error!("Error accepting key verification request: {}", e); } } - }, + } AnyToDeviceEvent::KeyVerificationKey(e) => { if let Some(Verification::SasV1(sas)) = &client @@ -244,9 +241,9 @@ async fn on_response(response: &SyncResponse, client: &Client) -> anyhow::Result error!("Error confirming key verification request: {}", e); } } - }, + } - _ => {}, + _ => {} } } diff --git a/src/meme.rs b/src/meme.rs index 261724e..568bb86 100644 --- a/src/meme.rs +++ b/src/meme.rs @@ -1,19 +1,52 @@ use crate::Bot; use rand::seq::IteratorRandom; +use serde::Deserialize; + +pub const ALLOWED_SPACES: &str = " ,.;:!?({-_"; #[derive(Debug)] pub struct Meme { pub keyword: String, pub ident: MemeIdent, + pub matcher: Matcher, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Matcher { + Begins, + Contains, } impl Meme { /// checks if the meme should be triggered for the given message pub fn matches(&self, msg: &str) -> bool { let msg = msg.to_ascii_lowercase(); - msg.starts_with(&self.keyword) && - // msg must have one of allowed chars after keyword - msg.chars().nth(self.keyword.len()).map(|c|" ,.;:!?({-_".contains(c)).unwrap_or(true) + + match self.matcher { + Matcher::Begins => { + msg.starts_with(&self.keyword) && + // msg must have one of allowed chars after keyword + msg.chars().nth(self.keyword.len()).map(|c| ALLOWED_SPACES.contains(c)).unwrap_or(true) + } + + Matcher::Contains => msg + .match_indices(&self.keyword) + .map(|(idx, subs)| { + idx == 0 + || msg + .chars() + .nth(idx - 1) + .map(|c| ALLOWED_SPACES.contains(c)) + .unwrap_or(true) + && msg + .chars() + .nth(idx + subs.len()) + .map(|c| ALLOWED_SPACES.contains(c)) + .unwrap_or(true) + }) + .any(|b| b), + } } pub async fn get_meme(&self, bot: &Bot) -> anyhow::Result> { @@ -37,3 +70,42 @@ pub enum MemeIdent { RandomCat(String), Id(u32), } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn matches_begins_test() { + let meme = Meme { + keyword: String::from("test"), + ident: MemeIdent::Id(42), + matcher: Matcher::Begins, + }; + + assert!(!meme.matches("xxx")); + assert!(!meme.matches("testxxx")); + assert!(meme.matches("test")); + assert!(meme.matches("test xxx")); + assert!(meme.matches("test; xxx")); + assert!(meme.matches("test;xxx")); + assert!(meme.matches("TEST")); + assert!(meme.matches("TeSt")); + } + + #[test] + fn matches_contains_test() { + let meme = Meme { + keyword: String::from("test"), + ident: MemeIdent::Id(42), + matcher: Matcher::Contains, + }; + + assert!(!meme.matches("xxx")); + assert!(!meme.matches("xxxtestxxx")); + assert!(meme.matches("xxx test xxx")); + assert!(meme.matches("xxx,test.xxx")); + assert!(meme.matches("xxx,TEST.xxx")); + assert!(meme.matches("xxx,TeSt.xxx")); + } +} diff --git a/src/responder.rs b/src/responder.rs index cee4e30..6bca4cc 100644 --- a/src/responder.rs +++ b/src/responder.rs @@ -6,10 +6,7 @@ use matrix_sdk::{ events::{ room::{ message::{ - ImageMessageEventContent, - MessageEventContent, - MessageType, - VideoInfo, + ImageMessageEventContent, MessageEventContent, MessageType, VideoInfo, VideoMessageEventContent, }, ImageInfo, @@ -76,7 +73,7 @@ pub async fn on_msg(msg: &str, room: Room, bot: &Bot) -> anyhow::Result<()> { room.name() ); return Ok(()); - }, + } }; for meme in &bot.config.memes { @@ -92,7 +89,7 @@ pub async fn on_msg(msg: &str, room: Room, bot: &Bot) -> anyhow::Result<()> { again: {}", &meme, e ); - }, + } Ok(id) => { if let Some(ivec) = bot.memecache.get(id.to_le_bytes())? { let cached = bincode::deserialize::(&ivec)?; @@ -137,7 +134,7 @@ pub async fn on_msg(msg: &str, room: Room, bot: &Bot) -> anyhow::Result<()> { ); } } - }, + } } } else { error!("Found meme with invalid id! {:?}", &meme); diff --git a/src/util.rs b/src/util.rs index b729ba0..7b970b6 100644 --- a/src/util.rs +++ b/src/util.rs @@ -2,8 +2,7 @@ pub mod mime_serialize { use mime::Mime; use serde::{ de::{self, Unexpected, Visitor}, - Deserializer, - Serializer, + Deserializer, Serializer, }; pub fn serialize(data: &Mime, serializer: S) -> Result