From 21ae63d46b744a73a3497dddde2e336993981b38 Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Wed, 9 Feb 2022 12:32:18 +0100 Subject: [PATCH] Rewrite query parameter parsing --- src/ruma_wrapper/axum.rs | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/ruma_wrapper/axum.rs b/src/ruma_wrapper/axum.rs index d2cf3f15..cec82126 100644 --- a/src/ruma_wrapper/axum.rs +++ b/src/ruma_wrapper/axum.rs @@ -18,7 +18,8 @@ use ruma::{ signatures::CanonicalJsonValue, DeviceId, Outgoing, ServerName, UserId, }; -use tracing::{debug, warn}; +use serde::Deserialize; +use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; use crate::{database::DatabaseGuard, server_server, Error, Result}; @@ -35,18 +36,31 @@ where type Rejection = Error; async fn from_request(req: &mut RequestParts) -> Result { + #[derive(Deserialize)] + struct QueryParams { + access_token: Option, + user_id: Option, + } + let metadata = T::Incoming::METADATA; let db = DatabaseGuard::from_request(req).await?; let auth_header = Option::>>::from_request(req).await?; - // FIXME: Do this more efficiently - let query: BTreeMap = - ruma::serde::urlencoded::from_str(req.uri().query().unwrap_or_default()) - .expect("Query to string map deserialization should be fine"); + let query = req.uri().query().unwrap_or_default(); + let query_params: QueryParams = match ruma::serde::urlencoded::from_str(query) { + Ok(params) => params, + Err(e) => { + error!(%query, "Failed to deserialize query parameters: {}", e); + return Err(Error::BadRequest( + ErrorKind::Unknown, + "Failed to read query parameters", + )); + } + }; let token = match &auth_header { Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()), - None => query.get("access_token").map(|tok| tok.as_str()), + None => query_params.access_token.as_deref(), }; let mut body = Bytes::from_request(req) @@ -67,7 +81,7 @@ where if let Some((_id, registration)) = appservice_registration { match metadata.authentication { AuthScheme::AccessToken | AuthScheme::QueryOnlyAccessToken => { - let user_id = query.get("user_id").map_or_else( + let user_id = query_params.user_id.map_or_else( || { UserId::parse_with_server_name( registration @@ -79,7 +93,7 @@ where ) .unwrap() }, - |s| UserId::parse(s.as_str()).unwrap(), + |s| UserId::parse(s).unwrap(), ); if !db.users.exists(&user_id).unwrap() {