diff --git a/src/cdn/error.rs b/src/cdn/error.rs index 8b4f078..97d36f5 100644 --- a/src/cdn/error.rs +++ b/src/cdn/error.rs @@ -5,42 +5,35 @@ use axum::{ response::IntoResponse, }; use reqwest::StatusCode; +use thiserror::Error; -pub struct Error(StatusCode); +use crate::ipfs::error::IPFSError; -impl Error { - pub fn new() -> Self { - Error(StatusCode::INTERNAL_SERVER_ERROR) - } +#[derive(Error, Debug)] +pub enum CDNError { + #[error("SQL error: {0}")] + SQL(#[from] sqlx::Error), + #[error("IPFS error: {0}")] + IPFS(#[from] IPFSError), + #[error("Decode error: {0}")] + Decode(#[from] FromUtf8Error), + #[error("Internal server error")] + Internal, } -impl IntoResponse for Error { +impl IntoResponse for CDNError { type Body = Empty; type BodyError = Infallible; fn into_response(self) -> axum::http::Response { - self.0.into_response() - } -} - -impl From for Error { - fn from(err: sqlx::Error) -> Self { - match err { - sqlx::Error::RowNotFound => Error(StatusCode::NOT_FOUND), - _ => Error(StatusCode::INTERNAL_SERVER_ERROR), - } - } -} - -impl From for Error { - fn from(err: reqwest::Error) -> Self { - Error(err.status().unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) - } -} - -impl From for Error { - fn from(_: FromUtf8Error) -> Self { - Error(StatusCode::INTERNAL_SERVER_ERROR) + let status = match self { + CDNError::SQL(err) => match err { + sqlx::Error::RowNotFound => StatusCode::NOT_FOUND, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + status.into_response() } } diff --git a/src/cdn/mod.rs b/src/cdn/mod.rs index f9a3178..d888cb0 100644 --- a/src/cdn/mod.rs +++ b/src/cdn/mod.rs @@ -17,7 +17,7 @@ use sqlx::MySqlPool; use crate::config::ConfVars; use self::{ - error::Error, + error::CDNError, templates::{DirTemplate, HtmlTemplate}, }; @@ -37,58 +37,48 @@ async fn image( Path((user, filename)): Path<(String, String)>, Extension(db_pool): Extension, Extension(vars): Extension, -) -> Result { +) -> Result { let filename = urlencoding::decode(&filename)?.into_owned(); let cid = sql::get_cid(user, filename.clone(), &db_pool).await?; let ipfs = vars.ipfs_client()?; let res = ipfs.cat(cid).await?; let clength = res .headers() - .get(HeaderName::from_static("x-content-length")); - match clength { - Some(h) => { - let mut headers = HeaderMap::new(); - let ctype = - ContentType::from(new_mime_guess::from_path(filename).first_or_octet_stream()); - headers.typed_insert(ctype); - headers.insert(CONTENT_LENGTH, h.clone()); + .get(HeaderName::from_static("x-content-length")) + .ok_or(CDNError::Internal)?; - Ok(( - StatusCode::OK, - headers, - Body::wrap_stream(res.bytes_stream()), - )) - } - None => Err(Error::new()), - } + let mut headers = HeaderMap::new(); + let ctype = ContentType::from(new_mime_guess::from_path(filename).first_or_octet_stream()); + headers.typed_insert(ctype); + headers.insert(CONTENT_LENGTH, clength.clone()); + + Ok(( + StatusCode::OK, + headers, + Body::wrap_stream(res.bytes_stream()), + )) } async fn users( Extension(db_pool): Extension, Extension(vars): Extension, -) -> Result { - let q = sql::get_users(&db_pool).await; - match q { - Ok(users) => Ok(HtmlTemplate(DirTemplate { - entries: users, - prefix: vars.cdn, - suffix: "/".to_string(), - })), - Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), - } +) -> Result { + let users = sql::get_users(&db_pool).await?; + Ok(HtmlTemplate(DirTemplate { + entries: users, + prefix: vars.cdn, + suffix: "/".to_string(), + })) } async fn memes( Path(user): Path, Extension(db_pool): Extension, -) -> Result { - let q = sql::get_memes(user, &db_pool).await; - match q { - Ok(memes) => Ok(HtmlTemplate(DirTemplate { - entries: memes, - prefix: ".".to_string(), - suffix: "".to_string(), - })), - Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR), - } +) -> Result { + let memes = sql::get_memes(user, &db_pool).await?; + Ok(HtmlTemplate(DirTemplate { + entries: memes, + prefix: ".".to_string(), + suffix: "".to_string(), + })) } diff --git a/src/error.rs b/src/error.rs index e14bbfc..5a81af8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,5 @@ use thiserror::Error; - #[derive(Error, Debug)] pub enum JMError { #[error("File read error: {0}")] @@ -11,4 +10,4 @@ pub enum JMError { Database(#[from] sqlx::Error), #[error("Axum error: {0}")] Axum(#[from] hyper::Error), -} \ No newline at end of file +} diff --git a/src/ipfs/error.rs b/src/ipfs/error.rs new file mode 100644 index 0000000..2f17a96 --- /dev/null +++ b/src/ipfs/error.rs @@ -0,0 +1,10 @@ +use thiserror::Error; +use url::ParseError; + +#[derive(Error, Debug)] +pub enum IPFSError { + #[error("Reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), + #[error("URL parse error: {0}")] + URL(#[from] ParseError), +} diff --git a/src/ipfs/mod.rs b/src/ipfs/mod.rs index 99ce58b..d47b553 100644 --- a/src/ipfs/mod.rs +++ b/src/ipfs/mod.rs @@ -1,12 +1,20 @@ -use reqwest::{Client, Response, Result, Url}; +use std::time::Duration; + +use axum::{body::Bytes, http::request}; +use reqwest::{ + multipart::{Form, Part}, + Body, Client, Response, Url, +}; use serde::{Deserialize, Serialize}; use crate::config::ConfVars; +use self::error::IPFSError; + +pub(crate) mod error; + #[derive(Deserialize)] -pub struct AddResponse { - #[serde(rename = "Bytes")] - pub bytes: String, +pub struct IPFSFile { #[serde(rename = "Hash")] pub hash: String, #[serde(rename = "Name")] @@ -20,46 +28,76 @@ pub struct CatQuery { pub arg: String, } +#[derive(Serialize)] +pub struct AddQuery { + pub pin: bool, +} + +#[derive(Serialize)] +pub struct PinQuery { + pub arg: String, +} + pub struct IpfsClient { url: Url, client: Client, } impl IpfsClient { - - pub fn cat_url(&self) -> Url { - self.url.join("/api/v0/cat").expect("Something went wrong with the IPFS URL") + pub async fn cat(&self, cid: String) -> Result { + let request = self + .client + .post(self.url.join("/api/v0/cat")?) + .query(&CatQuery::new(cid)); + Ok(request.send().await?) } - pub fn add_url(&self) -> Url { - self.url.join("/api/v0/add").expect("Something went wrong with the IPFS URL") + pub async fn add(&self, file: Bytes, filename: String) -> Result { + let request = self + .client + .post(self.url.join("/api/v0/add")?) + .query(&AddQuery::new(false)) + .multipart(Form::new().part("file", Part::stream(file).file_name(filename))); + let response = request.send().await?; + let res: IPFSFile = response.json().await?; + Ok(res) } - pub async fn cat(&self, cid: String) -> Result { - let request = self.client.post(self.cat_url()).query(&CatQuery::new(cid)); - request.send().await + pub async fn pin(&self, cid: String) -> Result<(), IPFSError> { + let request = self + .client + .post(self.url.join("/api/v0/pin/add")?) + .query(&PinQuery::new(cid)) + .timeout(Duration::from_secs(60)); + let response = request.send().await?; + Ok(()) } - } impl CatQuery { - pub fn new(cid: String) -> Self { - Self { - arg: cid, - } + Self { arg: cid } } +} +impl AddQuery { + pub fn new(pin: bool) -> Self { + Self { pin } + } +} + +impl PinQuery { + pub fn new(cid: String) -> Self { + Self { arg: cid } + } } impl ConfVars { - - pub fn ipfs_client(&self) -> Result { - let client =reqwest::ClientBuilder::new().user_agent("curl").build()?; + pub fn ipfs_client(&self) -> Result { + let client = reqwest::ClientBuilder::new().user_agent("curl").build()?; Ok(IpfsClient { url: self.ipfs_api.clone(), client, }) } - -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index d035681..af19ae0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,9 +12,9 @@ use tower_http::{add_extension::AddExtensionLayer, set_header::SetResponseHeader mod cdn; mod config; +mod error; mod ipfs; mod v1; -mod error; #[derive(StructOpt)] struct Opt { @@ -33,8 +33,7 @@ async fn main() -> Result<(), JMError> { let config = std::fs::read(&opt.config)?; let config = toml::from_slice::(&config)?; - let db_pool = MySqlPool::new(&config.database) - .await?; + let db_pool = MySqlPool::new(&config.database).await?; let app = Router::new() .nest("/api/v1", v1::routes()) diff --git a/src/v1/error.rs b/src/v1/error.rs index b29a242..b470090 100644 --- a/src/v1/error.rs +++ b/src/v1/error.rs @@ -10,6 +10,7 @@ use reqwest::StatusCode; use thiserror::Error; use super::models::ErrorResponse; +use crate::ipfs::error::IPFSError; #[derive(Error, Debug)] pub enum APIError { @@ -19,6 +20,8 @@ pub enum APIError { Multipart(#[from] MultipartError), #[error("Bad request: {0}")] BadRequest(String), + #[error("IPFS error: {0}")] + IPFS(#[from] IPFSError), } impl ErrorResponse { @@ -44,6 +47,7 @@ impl IntoResponse for APIError { }, APIError::Multipart(_) => ErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR, None), APIError::BadRequest(err) => ErrorResponse::new(StatusCode::BAD_REQUEST, Some(err)), + APIError::IPFS(_) => ErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR, None), }; let status = res.status.clone(); (status, Json(res)).into_response()