diff --git a/src/cdn/mod.rs b/src/cdn/mod.rs index d888cb0..9880a3c 100644 --- a/src/cdn/mod.rs +++ b/src/cdn/mod.rs @@ -14,7 +14,7 @@ use reqwest::{ }; use sqlx::MySqlPool; -use crate::config::ConfVars; +use crate::JMService; use self::{ error::CDNError, @@ -36,12 +36,11 @@ pub fn routes() -> Router { async fn image( Path((user, filename)): Path<(String, String)>, Extension(db_pool): Extension, - Extension(vars): Extension, + Extension(service): Extension, ) -> Result { let filename = urlencoding::decode(&filename)?.into_owned(); let cid = sql::get_cid(user, filename.clone(), &db_pool).await?; - let ipfs = vars.ipfs_client()?; - let res = ipfs.cat(cid).await?; + let res = service.cat(cid).await?; let clength = res .headers() .get(HeaderName::from_static("x-content-length")) @@ -61,12 +60,12 @@ async fn image( async fn users( Extension(db_pool): Extension, - Extension(vars): Extension, + Extension(service): Extension, ) -> Result { let users = sql::get_users(&db_pool).await?; Ok(HtmlTemplate(DirTemplate { entries: users, - prefix: vars.cdn, + prefix: service.cdn_url(), suffix: "/".to_string(), })) } diff --git a/src/config.rs b/src/config.rs index 877a2f1..d1d19bb 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,8 @@ use reqwest::Url; use serde::Deserialize; -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; + +use crate::{error::JMError, JMService, JMServiceInner}; #[derive(Deserialize)] pub struct Config { @@ -10,25 +12,19 @@ pub struct Config { pub ipfs_api: Url, } -pub struct ConfVars { - pub cdn: String, - pub ipfs_api: Url, -} - impl Config { - pub fn vars(&self) -> ConfVars { - ConfVars { - cdn: self.cdn.clone(), - ipfs_api: self.ipfs_api.clone(), - } + pub fn service(&self) -> Result { + let client = reqwest::ClientBuilder::new().user_agent("curl").build()?; + Ok(Arc::new(JMServiceInner { + client, + ipfs_url: self.ipfs_api.clone(), + cdn_url: self.cdn.clone(), + })) } } -impl Clone for ConfVars { - fn clone(&self) -> Self { - Self { - cdn: self.cdn.clone(), - ipfs_api: self.ipfs_api.clone(), - } +impl JMServiceInner { + pub fn cdn_url(&self) -> String { + self.cdn_url.clone() } } diff --git a/src/error.rs b/src/error.rs index 5a81af8..a9a5383 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,4 +10,6 @@ pub enum JMError { Database(#[from] sqlx::Error), #[error("Axum error: {0}")] Axum(#[from] hyper::Error), + #[error("Reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), } diff --git a/src/ipfs/mod.rs b/src/ipfs/mod.rs index 752f9d8..23691ea 100644 --- a/src/ipfs/mod.rs +++ b/src/ipfs/mod.rs @@ -3,11 +3,11 @@ use std::time::Duration; use axum::body::Bytes; use reqwest::{ multipart::{Form, Part}, - Client, Response, Url, + Response, }; use serde::{Deserialize, Serialize}; -use crate::config::ConfVars; +use crate::JMServiceInner; use self::error::IPFSError; @@ -38,16 +38,11 @@ pub struct PinQuery { pub arg: String, } -pub struct IpfsClient { - url: Url, - client: Client, -} - -impl IpfsClient { +impl JMServiceInner { pub async fn cat(&self, cid: String) -> Result { let request = self .client - .post(self.url.join("/api/v0/cat")?) + .post(self.ipfs_url.join("/api/v0/cat")?) .query(&CatQuery::new(cid)); Ok(request.send().await?) } @@ -55,7 +50,7 @@ impl IpfsClient { pub async fn add(&self, file: Bytes, filename: String) -> Result { let request = self .client - .post(self.url.join("/api/v0/add")?) + .post(self.ipfs_url.join("/api/v0/add")?) .query(&AddQuery::new(false)) .multipart(Form::new().part("file", Part::stream(file).file_name(filename))); let response = request.send().await?; @@ -66,7 +61,7 @@ impl IpfsClient { pub async fn pin(&self, cid: String) -> Result<(), IPFSError> { let request = self .client - .post(self.url.join("/api/v0/pin/add")?) + .post(self.ipfs_url.join("/api/v0/pin/add")?) .query(&PinQuery::new(cid)) .timeout(Duration::from_secs(60)); request.send().await?; @@ -91,13 +86,3 @@ impl PinQuery { Self { arg: cid } } } - -impl ConfVars { - pub fn ipfs_client(&self) -> Result { - let client = reqwest::ClientBuilder::new().user_agent("curl").build()?; - Ok(IpfsClient { - url: self.ipfs_api.clone(), - client, - }) - } -} diff --git a/src/main.rs b/src/main.rs index 90b8248..5e72eef 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,8 +5,9 @@ use axum::{ }; use config::Config; use error::JMError; +use reqwest::{Client, Url}; use sqlx::MySqlPool; -use std::path::PathBuf; +use std::{path::PathBuf, sync::Arc}; use structopt::StructOpt; use tower_http::{add_extension::AddExtensionLayer, set_header::SetResponseHeaderLayer}; @@ -30,6 +31,14 @@ struct Opt { config: PathBuf, } +pub struct JMServiceInner { + client: Client, + ipfs_url: Url, + cdn_url: String, +} + +pub type JMService = Arc; + #[tokio::main] async fn main() -> Result<(), JMError> { let opt = Opt::from_args(); @@ -37,12 +46,13 @@ async fn main() -> Result<(), JMError> { let config = toml::from_slice::(&config)?; let db_pool = MySqlPool::new(&config.database).await?; + let service = config.service()?; let app = Router::new() .nest("/api/v1", v1::routes()) .nest("/cdn", cdn::routes()) .layer(AddExtensionLayer::new(db_pool)) - .layer(AddExtensionLayer::new(config.vars())) + .layer(AddExtensionLayer::new(service)) .layer(SetResponseHeaderLayer::<_, Request>::if_not_present( header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"), diff --git a/src/v1/routes.rs b/src/v1/routes.rs index 632af7e..2c97e58 100644 --- a/src/v1/routes.rs +++ b/src/v1/routes.rs @@ -1,8 +1,8 @@ -use crate::config::ConfVars; use crate::ipfs::IPFSFile; use crate::lib::ExtractIP; use crate::models::{Category, Meme, MemeFilter, User}; use crate::v1::models::*; +use crate::JMService; use axum::extract::{ContentLengthLimit, Extension, Multipart}; use axum::handler::{get, post}; @@ -18,13 +18,13 @@ use super::Query; async fn meme( Query(params): Query, Extension(db_pool): Extension, - Extension(vars): Extension, + Extension(service): Extension, ) -> Result { let meme = V1Meme::new( Meme::get(params.id, &db_pool) .await? .ok_or_else(|| APIError::NotFound("Meme not found".to_string()))?, - vars.cdn, + service.cdn_url(), ); Ok(Json(MemeResponse { status: 200, @@ -36,12 +36,12 @@ async fn meme( async fn memes( Query(params): Query, Extension(db_pool): Extension, - Extension(vars): Extension, + Extension(service): Extension, ) -> Result { let memes = Meme::get_all(params, &db_pool) .await? .into_iter() - .map(|meme| V1Meme::new(meme, vars.cdn.clone())) + .map(|meme| V1Meme::new(meme, service.cdn_url())) .collect(); Ok(Json(MemesResponse { status: 200, @@ -101,9 +101,9 @@ async fn users(Extension(db_pool): Extension) -> Result, Extension(db_pool): Extension, - Extension(vars): Extension, + Extension(service): Extension, ) -> Result { - let random = V1Meme::new(Meme::get_random(params, &db_pool).await?, vars.cdn); + let random = V1Meme::new(Meme::get_random(params, &db_pool).await?, service.cdn_url()); Ok(Json(MemeResponse { status: 200, error: None, @@ -114,15 +114,13 @@ async fn random( async fn upload( ContentLengthLimit(mut form): ContentLengthLimit, Extension(db_pool): Extension, - Extension(vars): Extension, + Extension(service): Extension, ExtractIP(ip): ExtractIP, ) -> Result { let mut category: Option = None; let mut token: Option = None; let mut files: Vec = vec![]; - let ipfs = vars.ipfs_client()?; - while let Some(field) = form.next_field().await? { match field.name().ok_or_else(|| { APIError::BadRequest("A multipart-form field is missing a name".to_string()) @@ -136,7 +134,7 @@ async fn upload( APIError::BadRequest("A file field has no filename".to_string()) })? .to_string(); - let file = ipfs.add(field.bytes().await?, filename).await?; + let file = service.add(field.bytes().await?, filename).await?; files.push(file); } _ => (), @@ -169,10 +167,10 @@ async fn upload( return Err(APIError::Internal("Database insertion error".to_string())); } - ipfs.pin(f.hash).await?; + service.pin(f.hash).await?; links.push(format!( "{}/{}/{}", - vars.cdn, + service.cdn_url(), user.id.clone(), f.name.clone() ));