diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f59f3b..327a82b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,2 +1 @@ -- Port API to Rust -- Use IPFS as CDN \ No newline at end of file +- Return API error response on missing query parameter \ No newline at end of file diff --git a/src/v1/error.rs b/src/v1/error.rs index 648a135..c59fa3f 100644 --- a/src/v1/error.rs +++ b/src/v1/error.rs @@ -2,7 +2,7 @@ use std::convert::Infallible; use axum::{ body::{Bytes, Full}, - extract::multipart::MultipartError, + extract::{multipart::MultipartError, rejection::QueryRejection}, response::IntoResponse, Json, }; @@ -28,6 +28,8 @@ pub enum APIError { Internal(String), #[error("IPFS error: {0}")] Ipfs(#[from] IPFSError), + #[error("Query rejection: {0}")] + Query(#[from] QueryRejection), } impl ErrorResponse { @@ -59,6 +61,7 @@ impl IntoResponse for APIError { ErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR, Some(err)) } APIError::Ipfs(_) => ErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR, None), + APIError::Query(_) => ErrorResponse::new(StatusCode::BAD_REQUEST, None), }; let status = res.status; (status, Json(res)).into_response() diff --git a/src/v1/mod.rs b/src/v1/mod.rs index db74d77..40964c5 100644 --- a/src/v1/mod.rs +++ b/src/v1/mod.rs @@ -3,4 +3,25 @@ pub mod models; mod routes; mod sql; +use async_trait::async_trait; +use axum::extract::{FromRequest, RequestParts}; pub use routes::routes; +use serde::de::DeserializeOwned; + +use self::error::APIError; + +pub struct Query(pub T); + +#[async_trait] +impl FromRequest for Query +where + T: DeserializeOwned, + B: Send, +{ + type Rejection = APIError; + + async fn from_request(req: &mut RequestParts) -> Result { + let query = axum::extract::Query::::from_request(req).await?; + Ok(Self(query.0)) + } +} diff --git a/src/v1/routes.rs b/src/v1/routes.rs index afa59f1..5e0f431 100644 --- a/src/v1/routes.rs +++ b/src/v1/routes.rs @@ -3,7 +3,7 @@ use crate::ipfs::IPFSFile; use crate::lib::ExtractIP; use crate::v1::models::*; -use axum::extract::{ContentLengthLimit, Extension, Multipart, Query}; +use axum::extract::{ContentLengthLimit, Extension, Multipart}; use axum::handler::{get, post}; use axum::response::IntoResponse; use axum::routing::BoxRoute; @@ -12,9 +12,10 @@ use hyper::StatusCode; use sqlx::MySqlPool; use super::error::APIError; +use super::Query; async fn meme( - params: Query, + Query(params): Query, Extension(db_pool): Extension, Extension(vars): Extension, ) -> Result { @@ -27,11 +28,11 @@ async fn meme( } async fn memes( - params: Query, + Query(params): Query, Extension(db_pool): Extension, Extension(vars): Extension, ) -> Result { - let memes = Meme::get_all(params.0, &db_pool, vars.cdn).await?; + let memes = Meme::get_all(params, &db_pool, vars.cdn).await?; Ok(Json(MemesResponse { status: 200, error: None, @@ -40,7 +41,7 @@ async fn memes( } async fn category( - params: Query, + Query(params): Query, Extension(db_pool): Extension, ) -> Result { let category = Category::get(¶ms.id, &db_pool).await?; @@ -63,10 +64,10 @@ async fn categories( } async fn user( - params: Query, + Query(params): Query, Extension(db_pool): Extension, ) -> Result { - let user = User::get(params.0, &db_pool).await?; + let user = User::get(params, &db_pool).await?; Ok(Json(UserResponse { status: 200, error: None, @@ -84,11 +85,11 @@ async fn users(Extension(db_pool): Extension) -> Result, + Query(params): Query, Extension(db_pool): Extension, Extension(vars): Extension, ) -> Result { - let random = Meme::get_random(params.0, &db_pool, vars.cdn).await?; + let random = Meme::get_random(params, &db_pool, vars.cdn).await?; Ok(Json(MemeResponse { status: 200, error: None,