RUFF/src/main.rs
2021-07-26 23:58:04 +02:00

295 lines
8.3 KiB
Rust

use anyhow::{anyhow, bail, Context};
use jm_client_core::JMClient;
use log::{error, info, warn};
use matrix_sdk::{
api::r0::session::login,
async_trait,
deserialized_responses::SyncResponse,
events::{
room::{
member::MemberEventContent,
message::{MessageEventContent, MessageType, TextMessageEventContent},
},
AnyToDeviceEvent,
StrippedStateEvent,
SyncMessageEvent,
},
room::Room,
verification::Verification,
EventHandler,
LoopCtrl,
};
use rand::{rngs::StdRng, SeedableRng};
use sled::Db;
use std::{
collections::BTreeMap,
path::PathBuf,
sync::{
atomic::{AtomicBool, AtomicU32},
Arc,
},
time::Duration,
};
use structopt::StructOpt;
use tokio::sync::{Mutex, RwLock};
use config::Config;
use matrix_sdk::{self, api::r0::uiaa::AuthData, identifiers::UserId, Client, SyncSettings};
use serde_json::json;
mod config;
mod meme;
mod responder;
mod util;
// Command-line options for the bot.
//
// Plain `//` comments are used on purpose: structopt turns `///` doc comments
// into `--help`/about text, which would change the CLI's output.
#[derive(Debug, StructOpt)]
struct Opt {
    // Path of the TOML configuration file. Note that `std::fs` never expands a
    // leading `~`; whoever opens this path has to expand it first.
    #[structopt(
        long,
        short,
        default_value = "~/.config/ruff/config.toml",
        help = "config file to use"
    )]
    config: PathBuf,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
env_logger::init();
let opt = Opt::from_args();
let config = std::fs::read(&opt.config).map_err(|e| anyhow!("Error reading config: {}", e))?;
let config =
toml::from_slice::<Config>(&config).map_err(|e| anyhow!("Error parsing config: {}", e))?;
let config = Arc::new(config);
let client = Arc::new(RwLock::new(Client::new(config.homeserver_url.clone())?));
let device_name = config.device_name.as_ref().map(String::as_ref);
let login::Response { user_id, .. } = client
.read()
.await
.login(&config.user_id, &config.password, device_name, device_name)
.await?;
client
.write()
.await
.set_event_handler(Box::new(Bot {
client: Arc::clone(&client),
jm_client: RwLock::new(JMClient::new()),
memecache: sled::open(config.store_path.join("memecache"))
.map_err(|e| anyhow!("error opening memecache: {}", e))?,
config: Arc::clone(&config),
meme_count: AtomicU32::new(0),
rng: Mutex::new(StdRng::from_rng(rand::thread_rng())?),
}))
.await;
let initial = AtomicBool::from(true);
let initial_ref = &initial;
let client_ref = &client.read().await;
let config_ref = &config;
let user_id_ref = &user_id;
client
.read()
.await
.sync_with_callback(SyncSettings::new(), |response| async move {
if let Err(e) = on_response(&response, client_ref).await {
error!("Error processing response: {}", e);
}
let initial = initial_ref;
if initial.load(std::sync::atomic::Ordering::SeqCst) {
if let Err(e) =
on_initial_response(client_ref, &user_id_ref, &config_ref.password).await
{
error!("Error processing initial response: {}", e);
}
initial.store(false, std::sync::atomic::Ordering::SeqCst);
}
LoopCtrl::Continue
})
.await;
Ok(())
}
/// Event handler that `main` installs on the matrix client; it carries the
/// state the message responders work with.
pub struct Bot {
/// Shared handle to the logged-in client; `main` keeps using the same handle
/// to run the sync loop.
client: Arc<RwLock<Client>>,
/// Client from `jm_client_core`; its cache is cleared once `meme_count`
/// reaches the threshold set in the config.
jm_client: RwLock<JMClient>,
/// sled database that `main` opens at `config.store_path/memecache`.
// NOTE(review): presumably caches fetched memes — confirm in `meme`/`responder`.
memecache: Db,
/// Parsed configuration, shared with `main`.
config: Arc<Config>,
/// used to keep track of how many memes have been sent.
/// this is reset once the threshold set in the config has been reached, and
/// the JMClient cache is cleared.
meme_count: AtomicU32,
/// Random number generator seeded from `rand::thread_rng()` at startup; the
/// mutex gives the `&self` handlers the mutable access `StdRng` needs.
rng: Mutex<StdRng>,
}
#[async_trait]
impl EventHandler for Bot {
    /// Automatically joins rooms the bot has been invited to.
    ///
    /// An invited room's stripped state can contain member events for several
    /// users (e.g. the inviter). Only the event whose `state_key` is the bot's
    /// own user id is the invite, so every other event is ignored. The join is
    /// retried with exponential backoff (2s, 4s, 8s, ...) and abandoned once the
    /// next delay would exceed one hour.
    async fn on_stripped_state_member(
        &self,
        room: Room,
        room_member: &StrippedStateEvent<MemberEventContent>,
        _: Option<MemberEventContent>,
    ) {
        let own_user_id = match self.client.read().await.user_id().await {
            Some(user_id) => user_id,
            None => {
                warn!("Ignoring membership event: client is not logged in");
                return;
            }
        };
        // Only membership events that concern the bot itself can be its invite.
        if room_member.state_key != own_user_id {
            return;
        }
        let room = match room {
            Room::Invited(room) => room,
            _ => return,
        };
        info!("Autojoining room {}", room.room_id());
        let mut delay = 2;
        loop {
            // The read guard is a temporary of this statement, so it is released
            // before sleeping below.
            let result = self.client.read().await.join_room_by_id(room.room_id()).await;
            let err = match result {
                Ok(_) => {
                    info!("Successfully joined room {}", room.room_id());
                    return;
                }
                Err(err) => err,
            };
            // retry autojoin due to synapse sending invites, before the
            // invited user can join for more information see
            // https://github.com/matrix-org/synapse/issues/4345
            warn!(
                "Failed to join room {} ({:?}), retrying in {}s",
                room.room_id(),
                err,
                delay
            );
            tokio::time::sleep(Duration::from_secs(delay)).await;
            delay *= 2;
            if delay > 3600 {
                error!("Can't join room {} ({:?})", room.room_id(), err);
                return;
            }
        }
    }

    /// Forwards plain-text messages from other users to the responder.
    ///
    /// Messages sent by the bot itself are skipped, as is every message when
    /// the client's own user id is unknown. Responder errors are logged.
    async fn on_room_message(&self, room: Room, msg: &SyncMessageEvent<MessageEventContent>) {
        let sent_by_self = self
            .client
            .read()
            .await
            .user_id()
            .await
            .map(|u| u == msg.sender)
            .unwrap_or(true);
        if sent_by_self {
            return;
        }
        if let SyncMessageEvent {
            content:
                MessageEventContent {
                    msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }),
                    ..
                },
            ..
        } = msg
        {
            if let Err(e) = responder::on_msg(msg_body, room, self).await {
                error!("Responder error: {}", e);
            }
        }
    }
}
/// Runs the one-time setup that has to wait for the first sync response.
///
/// At present this only bootstraps cross-signing for `user_id`. `password` is
/// used if the homeserver demands user-interactive authentication.
///
/// # Errors
/// Propagates any error from `bootstrap_cross_signing` unchanged.
async fn on_initial_response(
    client: &Client,
    user_id: &UserId,
    password: &str,
) -> anyhow::Result<()> {
    bootstrap_cross_signing(client, user_id, password).await
}
async fn on_response(response: &SyncResponse, client: &Client) -> anyhow::Result<()> {
for event in response
.to_device
.events
.iter()
.filter_map(|e| e.deserialize().ok())
{
match event {
AnyToDeviceEvent::KeyVerificationStart(e) => {
info!("Starting verification");
if let Some(Verification::SasV1(sas)) = &client
.get_verification(&e.sender, &e.content.transaction_id)
.await
{
if let Err(e) = sas.accept().await {
error!("Error accepting key verification request: {}", e);
}
}
},
AnyToDeviceEvent::KeyVerificationKey(e) => {
if let Some(Verification::SasV1(sas)) = &client
.get_verification(&e.sender, &e.content.transaction_id)
.await
{
if let Err(e) = sas.confirm().await {
error!("Error confirming key verification request: {}", e);
}
}
},
_ => {},
}
}
Ok(())
}
/// Builds an `m.login.password` user-interactive-auth payload for `user`.
///
/// The user is identified with an `m.id.user` identifier. The returned value
/// borrows only `session` (the UIAA session id returned by the server, if any);
/// the identifier and password are copied into the auth parameters.
fn auth_data<'a>(user: &UserId, password: &str, session: Option<&'a str>) -> AuthData<'a> {
    let identifier = json!({
        "type": "m.id.user",
        "user": user,
    });
    let auth_parameters: BTreeMap<_, _> = vec![
        ("identifier".to_owned(), identifier),
        ("password".to_owned(), json!(password)),
    ]
    .into_iter()
    .collect();
    AuthData::DirectRequest {
        kind: "m.login.password",
        session,
        auth_parameters,
    }
}
async fn bootstrap_cross_signing(
client: &Client,
user_id: &UserId,
password: &str,
) -> anyhow::Result<()> {
info!("bootstrapping e2e");
if let Err(e) = client.bootstrap_cross_signing(None).await {
if let Some(response) = e.uiaa_response() {
let auth_data = auth_data(&user_id, &password, response.session.as_deref());
client
.bootstrap_cross_signing(Some(auth_data))
.await
.context("Couldn't bootstrap cross signing")?;
} else {
bail!("Error during cross-signing bootstrap {:#?}", e);
}
}
info!("bootstrapped e2e");
Ok(())
}