Replace ConfVars with JMService
All checks were successful
continuous-integration/drone/push Build is passing
continuous-integration/drone Build is passing

This commit is contained in:
Timo Ley 2022-01-16 22:43:45 +01:00
parent 9133d1ea9e
commit fbca4b7c06
6 changed files with 49 additions and 59 deletions

View file

@ -14,7 +14,7 @@ use reqwest::{
}; };
use sqlx::MySqlPool; use sqlx::MySqlPool;
use crate::config::ConfVars; use crate::JMService;
use self::{ use self::{
error::CDNError, error::CDNError,
@ -36,12 +36,11 @@ pub fn routes() -> Router<BoxRoute> {
async fn image( async fn image(
Path((user, filename)): Path<(String, String)>, Path((user, filename)): Path<(String, String)>,
Extension(db_pool): Extension<MySqlPool>, Extension(db_pool): Extension<MySqlPool>,
Extension(vars): Extension<ConfVars>, Extension(service): Extension<JMService>,
) -> Result<impl IntoResponse, CDNError> { ) -> Result<impl IntoResponse, CDNError> {
let filename = urlencoding::decode(&filename)?.into_owned(); let filename = urlencoding::decode(&filename)?.into_owned();
let cid = sql::get_cid(user, filename.clone(), &db_pool).await?; let cid = sql::get_cid(user, filename.clone(), &db_pool).await?;
let ipfs = vars.ipfs_client()?; let res = service.cat(cid).await?;
let res = ipfs.cat(cid).await?;
let clength = res let clength = res
.headers() .headers()
.get(HeaderName::from_static("x-content-length")) .get(HeaderName::from_static("x-content-length"))
@ -61,12 +60,12 @@ async fn image(
async fn users( async fn users(
Extension(db_pool): Extension<MySqlPool>, Extension(db_pool): Extension<MySqlPool>,
Extension(vars): Extension<ConfVars>, Extension(service): Extension<JMService>,
) -> Result<impl IntoResponse, CDNError> { ) -> Result<impl IntoResponse, CDNError> {
let users = sql::get_users(&db_pool).await?; let users = sql::get_users(&db_pool).await?;
Ok(HtmlTemplate(DirTemplate { Ok(HtmlTemplate(DirTemplate {
entries: users, entries: users,
prefix: vars.cdn, prefix: service.cdn_url(),
suffix: "/".to_string(), suffix: "/".to_string(),
})) }))
} }

View file

@ -1,6 +1,8 @@
use reqwest::Url; use reqwest::Url;
use serde::Deserialize; use serde::Deserialize;
use std::net::SocketAddr; use std::{net::SocketAddr, sync::Arc};
use crate::{error::JMError, JMService, JMServiceInner};
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct Config { pub struct Config {
@ -10,25 +12,19 @@ pub struct Config {
pub ipfs_api: Url, pub ipfs_api: Url,
} }
pub struct ConfVars {
pub cdn: String,
pub ipfs_api: Url,
}
impl Config { impl Config {
pub fn vars(&self) -> ConfVars { pub fn service(&self) -> Result<JMService, JMError> {
ConfVars { let client = reqwest::ClientBuilder::new().user_agent("curl").build()?;
cdn: self.cdn.clone(), Ok(Arc::new(JMServiceInner {
ipfs_api: self.ipfs_api.clone(), client,
} ipfs_url: self.ipfs_api.clone(),
cdn_url: self.cdn.clone(),
}))
} }
} }
impl Clone for ConfVars { impl JMServiceInner {
fn clone(&self) -> Self { pub fn cdn_url(&self) -> String {
Self { self.cdn_url.clone()
cdn: self.cdn.clone(),
ipfs_api: self.ipfs_api.clone(),
}
} }
} }

View file

@ -10,4 +10,6 @@ pub enum JMError {
Database(#[from] sqlx::Error), Database(#[from] sqlx::Error),
#[error("Axum error: {0}")] #[error("Axum error: {0}")]
Axum(#[from] hyper::Error), Axum(#[from] hyper::Error),
#[error("Reqwest error: {0}")]
Reqwest(#[from] reqwest::Error),
} }

View file

@ -3,11 +3,11 @@ use std::time::Duration;
use axum::body::Bytes; use axum::body::Bytes;
use reqwest::{ use reqwest::{
multipart::{Form, Part}, multipart::{Form, Part},
Client, Response, Url, Response,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::config::ConfVars; use crate::JMServiceInner;
use self::error::IPFSError; use self::error::IPFSError;
@ -38,16 +38,11 @@ pub struct PinQuery {
pub arg: String, pub arg: String,
} }
pub struct IpfsClient { impl JMServiceInner {
url: Url,
client: Client,
}
impl IpfsClient {
pub async fn cat(&self, cid: String) -> Result<Response, IPFSError> { pub async fn cat(&self, cid: String) -> Result<Response, IPFSError> {
let request = self let request = self
.client .client
.post(self.url.join("/api/v0/cat")?) .post(self.ipfs_url.join("/api/v0/cat")?)
.query(&CatQuery::new(cid)); .query(&CatQuery::new(cid));
Ok(request.send().await?) Ok(request.send().await?)
} }
@ -55,7 +50,7 @@ impl IpfsClient {
pub async fn add(&self, file: Bytes, filename: String) -> Result<IPFSFile, IPFSError> { pub async fn add(&self, file: Bytes, filename: String) -> Result<IPFSFile, IPFSError> {
let request = self let request = self
.client .client
.post(self.url.join("/api/v0/add")?) .post(self.ipfs_url.join("/api/v0/add")?)
.query(&AddQuery::new(false)) .query(&AddQuery::new(false))
.multipart(Form::new().part("file", Part::stream(file).file_name(filename))); .multipart(Form::new().part("file", Part::stream(file).file_name(filename)));
let response = request.send().await?; let response = request.send().await?;
@ -66,7 +61,7 @@ impl IpfsClient {
pub async fn pin(&self, cid: String) -> Result<(), IPFSError> { pub async fn pin(&self, cid: String) -> Result<(), IPFSError> {
let request = self let request = self
.client .client
.post(self.url.join("/api/v0/pin/add")?) .post(self.ipfs_url.join("/api/v0/pin/add")?)
.query(&PinQuery::new(cid)) .query(&PinQuery::new(cid))
.timeout(Duration::from_secs(60)); .timeout(Duration::from_secs(60));
request.send().await?; request.send().await?;
@ -91,13 +86,3 @@ impl PinQuery {
Self { arg: cid } Self { arg: cid }
} }
} }
impl ConfVars {
pub fn ipfs_client(&self) -> Result<IpfsClient, IPFSError> {
let client = reqwest::ClientBuilder::new().user_agent("curl").build()?;
Ok(IpfsClient {
url: self.ipfs_api.clone(),
client,
})
}
}

View file

@ -5,8 +5,9 @@ use axum::{
}; };
use config::Config; use config::Config;
use error::JMError; use error::JMError;
use reqwest::{Client, Url};
use sqlx::MySqlPool; use sqlx::MySqlPool;
use std::path::PathBuf; use std::{path::PathBuf, sync::Arc};
use structopt::StructOpt; use structopt::StructOpt;
use tower_http::{add_extension::AddExtensionLayer, set_header::SetResponseHeaderLayer}; use tower_http::{add_extension::AddExtensionLayer, set_header::SetResponseHeaderLayer};
@ -30,6 +31,14 @@ struct Opt {
config: PathBuf, config: PathBuf,
} }
pub struct JMServiceInner {
client: Client,
ipfs_url: Url,
cdn_url: String,
}
pub type JMService = Arc<JMServiceInner>;
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), JMError> { async fn main() -> Result<(), JMError> {
let opt = Opt::from_args(); let opt = Opt::from_args();
@ -37,12 +46,13 @@ async fn main() -> Result<(), JMError> {
let config = toml::from_slice::<Config>(&config)?; let config = toml::from_slice::<Config>(&config)?;
let db_pool = MySqlPool::new(&config.database).await?; let db_pool = MySqlPool::new(&config.database).await?;
let service = config.service()?;
let app = Router::new() let app = Router::new()
.nest("/api/v1", v1::routes()) .nest("/api/v1", v1::routes())
.nest("/cdn", cdn::routes()) .nest("/cdn", cdn::routes())
.layer(AddExtensionLayer::new(db_pool)) .layer(AddExtensionLayer::new(db_pool))
.layer(AddExtensionLayer::new(config.vars())) .layer(AddExtensionLayer::new(service))
.layer(SetResponseHeaderLayer::<_, Request<Body>>::if_not_present( .layer(SetResponseHeaderLayer::<_, Request<Body>>::if_not_present(
header::ACCESS_CONTROL_ALLOW_ORIGIN, header::ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_static("*"), HeaderValue::from_static("*"),

View file

@ -1,8 +1,8 @@
use crate::config::ConfVars;
use crate::ipfs::IPFSFile; use crate::ipfs::IPFSFile;
use crate::lib::ExtractIP; use crate::lib::ExtractIP;
use crate::models::{Category, Meme, MemeFilter, User}; use crate::models::{Category, Meme, MemeFilter, User};
use crate::v1::models::*; use crate::v1::models::*;
use crate::JMService;
use axum::extract::{ContentLengthLimit, Extension, Multipart}; use axum::extract::{ContentLengthLimit, Extension, Multipart};
use axum::handler::{get, post}; use axum::handler::{get, post};
@ -18,13 +18,13 @@ use super::Query;
async fn meme( async fn meme(
Query(params): Query<MemeIDQuery>, Query(params): Query<MemeIDQuery>,
Extension(db_pool): Extension<MySqlPool>, Extension(db_pool): Extension<MySqlPool>,
Extension(vars): Extension<ConfVars>, Extension(service): Extension<JMService>,
) -> Result<impl IntoResponse, APIError> { ) -> Result<impl IntoResponse, APIError> {
let meme = V1Meme::new( let meme = V1Meme::new(
Meme::get(params.id, &db_pool) Meme::get(params.id, &db_pool)
.await? .await?
.ok_or_else(|| APIError::NotFound("Meme not found".to_string()))?, .ok_or_else(|| APIError::NotFound("Meme not found".to_string()))?,
vars.cdn, service.cdn_url(),
); );
Ok(Json(MemeResponse { Ok(Json(MemeResponse {
status: 200, status: 200,
@ -36,12 +36,12 @@ async fn meme(
async fn memes( async fn memes(
Query(params): Query<MemeFilter>, Query(params): Query<MemeFilter>,
Extension(db_pool): Extension<MySqlPool>, Extension(db_pool): Extension<MySqlPool>,
Extension(vars): Extension<ConfVars>, Extension(service): Extension<JMService>,
) -> Result<impl IntoResponse, APIError> { ) -> Result<impl IntoResponse, APIError> {
let memes = Meme::get_all(params, &db_pool) let memes = Meme::get_all(params, &db_pool)
.await? .await?
.into_iter() .into_iter()
.map(|meme| V1Meme::new(meme, vars.cdn.clone())) .map(|meme| V1Meme::new(meme, service.cdn_url()))
.collect(); .collect();
Ok(Json(MemesResponse { Ok(Json(MemesResponse {
status: 200, status: 200,
@ -101,9 +101,9 @@ async fn users(Extension(db_pool): Extension<MySqlPool>) -> Result<impl IntoResp
async fn random( async fn random(
Query(params): Query<MemeFilter>, Query(params): Query<MemeFilter>,
Extension(db_pool): Extension<MySqlPool>, Extension(db_pool): Extension<MySqlPool>,
Extension(vars): Extension<ConfVars>, Extension(service): Extension<JMService>,
) -> Result<impl IntoResponse, APIError> { ) -> Result<impl IntoResponse, APIError> {
let random = V1Meme::new(Meme::get_random(params, &db_pool).await?, vars.cdn); let random = V1Meme::new(Meme::get_random(params, &db_pool).await?, service.cdn_url());
Ok(Json(MemeResponse { Ok(Json(MemeResponse {
status: 200, status: 200,
error: None, error: None,
@ -114,15 +114,13 @@ async fn random(
async fn upload( async fn upload(
ContentLengthLimit(mut form): ContentLengthLimit<Multipart, { 1024 * 1024 * 1024 }>, ContentLengthLimit(mut form): ContentLengthLimit<Multipart, { 1024 * 1024 * 1024 }>,
Extension(db_pool): Extension<MySqlPool>, Extension(db_pool): Extension<MySqlPool>,
Extension(vars): Extension<ConfVars>, Extension(service): Extension<JMService>,
ExtractIP(ip): ExtractIP, ExtractIP(ip): ExtractIP,
) -> Result<impl IntoResponse, APIError> { ) -> Result<impl IntoResponse, APIError> {
let mut category: Option<String> = None; let mut category: Option<String> = None;
let mut token: Option<String> = None; let mut token: Option<String> = None;
let mut files: Vec<IPFSFile> = vec![]; let mut files: Vec<IPFSFile> = vec![];
let ipfs = vars.ipfs_client()?;
while let Some(field) = form.next_field().await? { while let Some(field) = form.next_field().await? {
match field.name().ok_or_else(|| { match field.name().ok_or_else(|| {
APIError::BadRequest("A multipart-form field is missing a name".to_string()) APIError::BadRequest("A multipart-form field is missing a name".to_string())
@ -136,7 +134,7 @@ async fn upload(
APIError::BadRequest("A file field has no filename".to_string()) APIError::BadRequest("A file field has no filename".to_string())
})? })?
.to_string(); .to_string();
let file = ipfs.add(field.bytes().await?, filename).await?; let file = service.add(field.bytes().await?, filename).await?;
files.push(file); files.push(file);
} }
_ => (), _ => (),
@ -169,10 +167,10 @@ async fn upload(
return Err(APIError::Internal("Database insertion error".to_string())); return Err(APIError::Internal("Database insertion error".to_string()));
} }
ipfs.pin(f.hash).await?; service.pin(f.hash).await?;
links.push(format!( links.push(format!(
"{}/{}/{}", "{}/{}/{}",
vars.cdn, service.cdn_url(),
user.id.clone(), user.id.clone(),
f.name.clone() f.name.clone()
)); ));