From 1d724b69722b605c158c3e2def38d86e412238a5 Mon Sep 17 00:00:00 2001 From: slonkazoid Date: Tue, 4 Feb 2025 23:01:02 +0300 Subject: [PATCH] proper regex --- src/config.rs | 31 ++++++++++++++++++++++--------- src/main.rs | 5 +++-- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/config.rs b/src/config.rs index 5596d6f..6d9d35a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use std::env; use std::net::{IpAddr, Ipv4Addr}; -use std::ops::Deref; use std::path::PathBuf; use color_eyre::eyre::{self, bail, Context}; @@ -11,13 +10,18 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tracing::{error, info, instrument}; #[derive(Deserialize, Serialize, Debug)] -pub struct DomainMatcher(#[serde(with = "regex")] Regex); +#[serde(untagged)] +pub enum DomainMatcher { + Regex(#[serde(with = "regex")] Regex), + Exact(String), +} -impl Deref for DomainMatcher { - type Target = Regex; - - fn deref(&self) -> &Self::Target { - &self.0 +impl DomainMatcher { + pub fn is_match(&self, value: impl AsRef) -> bool { + match self { + Self::Regex(regex) => regex.is_match(value.as_ref()).unwrap_or(false), + Self::Exact(exact) => exact == value.as_ref(), + } } } @@ -134,15 +138,24 @@ mod regex { impl Visitor<'_> for RegexVisitor { type Value = Regex; - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "a regex string") + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "a regex string starting and ending with slashes ('/')") } fn visit_str(self, v: &str) -> Result where E: serde::de::Error, { + let v = v + .strip_prefix('/') + .and_then(|v| v.strip_suffix('/')) + .ok_or(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(v), + &self, + ))?; + let regex = Regex::new(v); + regex.map_err(|err| serde::de::Error::custom(err)) } } diff --git a/src/main.rs b/src/main.rs index b2c6854..bb2a648 100644 --- a/src/main.rs +++ b/src/main.rs @@ -53,7 +53,8 @@ async fn update_repo_handler( if repo.secret.as_ref().is_some_and(|secret| { headers - .get(&secret.header).is_none_or(|header| header != &secret.value) + .get(&secret.header) + .is_none_or(|header| header != &secret.value) }) { Ok(StatusCode::UNAUTHORIZED) } else { @@ -77,7 +78,7 @@ async fn ask( Query(AskQuery { domain }): Query, ) -> StatusCode { for matcher in &config.domains { - if matcher.is_match(&domain).is_ok_and(|x| x) { + if matcher.is_match(&domain) { return StatusCode::OK; } }