diff --git a/Cargo.lock b/Cargo.lock index 1b08c7d..d4b681c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -295,6 +295,7 @@ name = "libslonk" version = "0.1.0" dependencies = [ "axum", + "serde", "tokio", "tower-http", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 208f2eb..4745b3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,13 +4,15 @@ version = "0.1.0" edition = "2024" [features] -default = ['macros', 'request-id'] +default = ['macros', 'request-id', 'filtered-path'] macros = [] request-id = ['dep:axum', 'dep:ulid', 'ulid/std'] +filtered-path = ['dep:axum', 'dep:serde'] [dependencies] -axum = { version = "0.8.1", optional = true, default-features = false } -ulid = { version = "1.1.4", optional = true, default-features = false } +axum = { version = "0.8.1", default-features = false, optional = true } +serde = { version = "1.0.217", default-features = false, optional = true } +ulid = { version = "1.1.4", default-features = false, optional = true } [dev-dependencies] axum = "0.8.1" diff --git a/examples/web.rs b/examples/web.rs index b973421..ce875b5 100644 --- a/examples/web.rs +++ b/examples/web.rs @@ -1,12 +1,15 @@ -use axum::extract::Path; use axum::middleware::from_fn; use axum::routing::get; use axum::{Extension, Router}; +use libslonk::filtered_path::{FilteredPath, SlashFilter}; use libslonk::{request_id, trace_layer_with_ulid}; use tokio::net::TcpListener; use ulid::Ulid; -async fn say_hello(Extension(id): Extension, Path(name): Path) -> String { +async fn say_hello( + Extension(id): Extension, + FilteredPath(name, ..): FilteredPath, +) -> String { format!("Hello, {name}. Your request has the ULID: {id}\n") } diff --git a/src/filtered_path.rs b/src/filtered_path.rs new file mode 100644 index 0000000..eddfd16 --- /dev/null +++ b/src/filtered_path.rs @@ -0,0 +1,107 @@ +use std::marker::PhantomData; + +use axum::extract::rejection::PathRejection; +use axum::extract::{FromRequestParts, Path}; +use axum::http::StatusCode; +use axum::http::request::Parts; +use axum::response::IntoResponse; +use serde::de::DeserializeOwned; + +/// Apply filters (specified at compile time) over the path, reject if it doesn't match. +pub struct FilteredPath(pub T, pub PhantomData); + +#[repr(u8)] +pub enum FilteredPathRejection { + FilterRejection(&'static str), + PathRejection(PathRejection), +} + +impl IntoResponse for FilteredPathRejection { + fn into_response(self) -> axum::response::Response { + match self { + Self::FilterRejection(s) => (StatusCode::BAD_REQUEST, s).into_response(), + Self::PathRejection(v) => v.into_response(), + } + } +} + +impl From for FilteredPathRejection { + fn from(value: PathRejection) -> Self { + Self::PathRejection(value) + } +} + +impl FromRequestParts for FilteredPath +where + T: AsRef, + F: Filter, + T: DeserializeOwned + Send, + S: Send + Sync, +{ + type Rejection = FilteredPathRejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let path: Path = Path::from_request_parts(parts, state).await?; + F::apply(path.0.as_ref()) + .map(|_| Self(path.0, PhantomData)) + .map_err(FilteredPathRejection::FilterRejection) + } +} + +/// Trait to apply a filter over request paths extracted with [`axum::extract::Path`]. +/// +/// [`axum::extract::Path`]: https://docs.rs/axum/0.8.1/axum/extract/struct.Path.html +pub trait Filter { + /// `Ok(())` if the filter passed, `Err(/*reason*/)` if it failed. + fn apply(path: &str) -> Result<(), &'static str>; +} + +/// Do not allow any forward slashes ('/'). +pub struct SlashFilter; +/// Do not allow '..' as a path component (between forward slashes). +pub struct DotDotFilter; + +impl Filter for SlashFilter { + fn apply(path: &str) -> Result<(), &'static str> { + if path.contains('/') { + Err("path contains a slash") + } else { + Ok(()) + } + } +} + +impl Filter for DotDotFilter { + fn apply(path: &str) -> Result<(), &'static str> { + if path.split('/').any(|p| p == "..") { + Err("path contains '..'") + } else { + Ok(()) + } + } +} + +/// Convenience type alias for `FilteredPath`. +pub type SafePath = FilteredPath; + +macro_rules! impl_filter_tuple { + ($($a:ident),+) => { + impl<$($a),+> Filter for ($($a),+,) + where $($a: Filter),+ { + fn apply(path: &str) -> Result<(), &'static str> { + $($a::apply(path)?;)+ + Ok(()) + } + } + }; +} + +macro_rules! impl_filter_tuple_for { + ($lhs:ident $(, $rhs:ident)* ) => { + impl_filter_tuple!($lhs $(, $rhs)*); + impl_filter_tuple_for!($($rhs),*); + }; + () => {}; // implemented all +} + +impl_filter_tuple_for!(T14, T13, T12, T11, T10, T9, T8, T7, T6, T5, T4, T3, T2, T1); diff --git a/src/lib.rs b/src/lib.rs index e9bb263..d44dcc2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,11 @@ +#[cfg(feature = "filtered-path")] +pub mod filtered_path; #[cfg(feature = "macros")] mod macros; #[cfg(feature = "request-id")] mod request_id; +#[cfg(feature = "filtered-path")] +pub use crate::filtered_path::FilteredPath; #[cfg(feature = "request-id")] pub use crate::request_id::*;