From 0ded637b4a6b885fc8d8015baeaaf1534b6b1d29 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Thu, 29 Jun 2023 11:20:52 +0200 Subject: [PATCH] Upgrade axum to 0.6 --- Cargo.lock | 54 +++++++------- Cargo.toml | 2 +- src/api/ruma_wrapper/axum.rs | 139 +++++++++++++++++++++++++---------- src/main.rs | 7 +- 4 files changed, 130 insertions(+), 72 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9f62c18f..41483941 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -89,9 +89,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "axum" -version = "0.5.17" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43" +checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39" dependencies = [ "async-trait", "axum-core", @@ -108,22 +108,22 @@ dependencies = [ "mime", "percent-encoding", "pin-project-lite", + "rustversion", "serde", "serde_json", + "serde_path_to_error", "serde_urlencoded", "sync_wrapper", - "tokio", "tower", - "tower-http 0.3.5", "tower-layer", "tower-service", ] [[package]] name = "axum-core" -version = "0.2.9" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" dependencies = [ "async-trait", "bytes", @@ -131,6 +131,7 @@ dependencies = [ "http", "http-body", "mime", + "rustversion", "tower-layer", "tower-service", ] @@ -407,7 +408,7 @@ dependencies = [ "tikv-jemallocator", "tokio", "tower", - "tower-http 0.4.1", + "tower-http", "tracing", "tracing-flame", "tracing-opentelemetry", @@ -1449,9 +1450,9 @@ checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" [[package]] name = "matchit" -version = "0.5.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" +checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" [[package]] name = "memchr" @@ -2363,6 +2364,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06" + [[package]] name = "ryu" version = "1.0.13" @@ -2467,6 +2474,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7f05c1d5476066defcdfacce1f52fc3cae3af1d3089727100c02ae92e5abbe0" +dependencies = [ + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.3" @@ -2954,31 +2970,11 @@ dependencies = [ "futures-util", "pin-project", "pin-project-lite", - "tokio", "tower-layer", "tower-service", "tracing", ] -[[package]] -name = "tower-http" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858" -dependencies = [ - "bitflags 1.3.2", - "bytes", - "futures-core", - "futures-util", - "http", - "http-body", - "http-range-header", - "pin-project-lite", - "tower", - "tower-layer", - "tower-service", -] - [[package]] name = "tower-http" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index 9698caff..424007c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ rust-version = "1.70.0" [dependencies] # Web framework -axum = { version = "0.5.16", default-features = false, features = ["form", "headers", "http1", "http2", "json", "matched-path"], optional = true } +axum = { version = "0.6.18", default-features = false, features = ["form", "headers", "http1", "http2", "json", "matched-path"], optional = true } axum-server = { version = "0.5.1", features = ["tls-rustls"] } tower = { version = "0.4.13", features = ["util"] } tower-http = { version = "0.4.1", features = ["add-extension", "cors", "sensitive-headers", "trace", "util"] } diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 2d2af705..069e12b3 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -3,18 +3,16 @@ use std::{collections::BTreeMap, iter::FromIterator, str}; use axum::{ async_trait, body::{Full, HttpBody}, - extract::{ - rejection::TypedHeaderRejectionReason, FromRequest, Path, RequestParts, TypedHeader, - }, + extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader}, headers::{ authorization::{Bearer, Credentials}, Authorization, }, response::{IntoResponse, Response}, - BoxError, + BoxError, RequestExt, RequestPartsExt, }; -use bytes::{BufMut, Bytes, BytesMut}; -use http::StatusCode; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use http::{Request, StatusCode}; use ruma::{ api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId, @@ -26,27 +24,44 @@ use super::{Ruma, RumaResponse}; use crate::{services, Error, Result}; #[async_trait] -impl FromRequest for Ruma +impl FromRequest for Ruma where T: IncomingRequest, - B: HttpBody + Send, + B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into, { type Rejection = Error; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: Request, _state: &S) -> Result { #[derive(Deserialize)] struct QueryParams { access_token: Option, user_id: Option, } - let metadata = T::METADATA; - let auth_header = Option::>>::from_request(req).await?; - let path_params = Path::>::from_request(req).await?; + let (mut parts, mut body) = match req.with_limited_body() { + Ok(limited_req) => { + let (parts, body) = limited_req.into_parts(); + let body = to_bytes(body) + .await + .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; + (parts, body) + } + Err(original_req) => { + let (parts, body) = original_req.into_parts(); + let body = to_bytes(body) + .await + .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; + (parts, body) + } + }; - let query = req.uri().query().unwrap_or_default(); + let metadata = T::METADATA; + let auth_header: Option>> = parts.extract().await?; + let path_params: Path> = parts.extract().await?; + + let query = parts.uri.query().unwrap_or_default(); let query_params: QueryParams = match serde_html_form::from_str(query) { Ok(params) => params, Err(e) => { @@ -63,10 +78,6 @@ where None => query_params.access_token.as_deref(), }; - let mut body = Bytes::from_request(req) - .await - .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; - let mut json_body = serde_json::from_slice::(&body).ok(); let appservices = services().appservice.all().unwrap(); @@ -138,24 +149,24 @@ where } } AuthScheme::ServerSignatures => { - let TypedHeader(Authorization(x_matrix)) = - TypedHeader::>::from_request(req) - .await - .map_err(|e| { - warn!("Missing or invalid Authorization header: {}", e); + let TypedHeader(Authorization(x_matrix)) = parts + .extract::>>() + .await + .map_err(|e| { + warn!("Missing or invalid Authorization header: {}", e); - let msg = match e.reason() { - TypedHeaderRejectionReason::Missing => { - "Missing Authorization header." - } - TypedHeaderRejectionReason::Error(_) => { - "Invalid X-Matrix signatures." - } - _ => "Unknown header-related error", - }; + let msg = match e.reason() { + TypedHeaderRejectionReason::Missing => { + "Missing Authorization header." + } + TypedHeaderRejectionReason::Error(_) => { + "Invalid X-Matrix signatures." + } + _ => "Unknown header-related error", + }; - Error::BadRequest(ErrorKind::Forbidden, msg) - })?; + Error::BadRequest(ErrorKind::Forbidden, msg) + })?; let origin_signatures = BTreeMap::from_iter([( x_matrix.key.clone(), @@ -170,11 +181,11 @@ where let mut request_map = BTreeMap::from_iter([ ( "method".to_owned(), - CanonicalJsonValue::String(req.method().to_string()), + CanonicalJsonValue::String(parts.method.to_string()), ), ( "uri".to_owned(), - CanonicalJsonValue::String(req.uri().to_string()), + CanonicalJsonValue::String(parts.uri.to_string()), ), ( "origin".to_owned(), @@ -224,7 +235,7 @@ where x_matrix.origin, e, request_map ); - if req.uri().to_string().contains('@') { + if parts.uri.to_string().contains('@') { warn!( "Request uri contained '@' character. Make sure your \ reverse proxy gives Conduit the raw uri (apache: use \ @@ -243,8 +254,8 @@ where } }; - let mut http_request = http::Request::builder().uri(req.uri()).method(req.method()); - *http_request.headers_mut().unwrap() = req.headers().clone(); + let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method); + *http_request.headers_mut().unwrap() = parts.headers; if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { let user_id = sender_user.clone().unwrap_or_else(|| { @@ -364,3 +375,55 @@ impl IntoResponse for RumaResponse { } } } + +// copied from hyper under the following license: +// Copyright (c) 2014-2021 Sean McArthur + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +pub(crate) async fn to_bytes(body: T) -> Result +where + T: HttpBody, +{ + futures_util::pin_mut!(body); + + // If there's only 1 chunk, we can just return Buf::to_bytes() + let mut first = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(Bytes::new()); + }; + + let second = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(first.copy_to_bytes(first.remaining())); + }; + + // With more than 1 buf, we gotta flatten into a Vec first. + let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize; + let mut vec = Vec::with_capacity(cap); + vec.put(first); + vec.put(second); + + while let Some(buf) = body.data().await { + vec.put(buf?); + } + + Ok(vec.into()) +} diff --git a/src/main.rs b/src/main.rs index f9f88f49..e0f84d9d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,8 +10,7 @@ use std::{future::Future, io, net::SocketAddr, sync::atomic, time::Duration}; use axum::{ - extract::{DefaultBodyLimit, FromRequest, MatchedPath}, - handler::Handler, + extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, response::IntoResponse, routing::{get, on, MethodFilter}, Router, @@ -421,7 +420,7 @@ fn routes() -> Router { "/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync), ) - .fallback(not_found.into_service()) + .fallback(not_found) } async fn shutdown_signal(handle: ServerHandle) { @@ -505,7 +504,7 @@ macro_rules! impl_ruma_handler { Fut: Future> + Send, E: IntoResponse, - $( $ty: FromRequest + Send + 'static, )* + $( $ty: FromRequestParts<()> + Send + 'static, )* { fn add_to_router(self, mut router: Router) -> Router { let meta = Req::METADATA;