Middleware Creation

Guide to creating custom middleware.

Table of Contents


Middleware trait

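All middleware must implement the `Middleware` trait. Its definition is shown below for reference — in your own code, import the trait from `reinhardt` rather than redefining it.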

use async_trait::async_trait;
use reinhardt::{Handler, Middleware, Request, Response, Result};
use std::sync::Arc;

#[async_trait]
pub trait Middleware: Send + Sync {
    /// Decides whether this middleware should handle `request` at all.
    ///
    /// Returns `true` for every request unless an implementation overrides it.
    fn should_continue(&self, _request: &Request) -> bool {
        true
    }

    /// Handles `request`. Implementations usually delegate to `next`, and may
    /// act on the request before that call and on the response after it.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response>;
}

Basic Middleware

Request Logging Middleware

use async_trait::async_trait;
use reinhardt::{Handler, Middleware, Request, Response, Result};
use std::sync::Arc;

pub struct LoggingMiddleware;

#[async_trait]
impl Middleware for LoggingMiddleware {
    /// Prints the incoming method and path, delegates to `next`, then prints the
    /// status of the returned response.
    ///
    /// If `next` fails, its error is returned unchanged and no response line is
    /// printed.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        println!("Request: {} {}", request.method, request.uri.path());

        // Hand off to the rest of the chain; only a successful response is logged.
        next.handle(request).await.map(|response| {
            println!("Response: {}", response.status);
            response
        })
    }
}

Custom Header Middleware

use async_trait::async_trait;
use reinhardt::{Handler, Middleware, Request, Response, Result};
use std::sync::Arc;

/// Middleware that adds a single, fixed header to responses.
pub struct CustomHeaderMiddleware {
    /// Name of the header to add.
    pub header_name: String,
    /// Value sent for that header.
    pub header_value: String,
}

impl CustomHeaderMiddleware {
    /// Builds the middleware from a header name and value, copying both into
    /// owned strings.
    pub fn new(name: &str, value: &str) -> Self {
        let header_name = String::from(name);
        let header_value = String::from(value);
        Self {
            header_name,
            header_value,
        }
    }
}

#[async_trait]
impl Middleware for CustomHeaderMiddleware {
    /// Calls the downstream handler, then inserts the configured header into its
    /// response, replacing any existing header with the same name.
    ///
    /// If `header_name` is not a valid HTTP header name, or `header_value` is not
    /// a valid header value, the header is skipped and a warning is printed to
    /// stderr. The request is not panicked.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        let mut response = next.handle(request).await?;

        let name = hyper::header::HeaderName::from_bytes(self.header_name.as_bytes());
        let value = hyper::header::HeaderValue::from_str(&self.header_value);

        match (name, value) {
            (Ok(name), Ok(value)) => {
                response.headers.insert(name, value);
            }
            // The name and value are plain configuration strings; a typo there
            // must not become a panic on every request.
            _ => eprintln!(
                "CustomHeaderMiddleware: skipping invalid header {:?}: {:?}",
                self.header_name, self.header_value
            ),
        }

        Ok(response)
    }
}

Conditional Execution

Implementing should_continue()

Override `should_continue()` so that a middleware processes only the requests that meet a condition — in this example, paths under `/admin/`.

use async_trait::async_trait;
use reinhardt::{Handler, Middleware, Request, Response, Result};
use std::sync::Arc;

/// Middleware that guards `/admin/` paths behind an admin bearer token.
pub struct AdminOnlyMiddleware;

#[async_trait]
impl Middleware for AdminOnlyMiddleware {
    /// Restricts this middleware to request paths starting with `/admin/`.
    fn should_continue(&self, request: &Request) -> bool {
        request.uri.path().starts_with("/admin/")
    }

    /// Forwards the request to `next` only when its `Authorization` header starts
    /// with `Bearer admin`. Otherwise it returns `Response::forbidden()` without
    /// calling `next`.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        // Placeholder check for illustration only: a prefix match also accepts
        // values such as "Bearer administrator". Replace it with real token
        // verification in production code.
        let is_admin = matches!(
            request.get_header("Authorization"),
            Some(value) if value.starts_with("Bearer admin")
        );

        if is_admin {
            next.handle(request).await
        } else {
            Ok(Response::forbidden())
        }
    }
}

Stateful Middleware

Rate Limiting Middleware

use async_trait::async_trait;
use reinhardt::{Handler, Middleware, Request, Response, Result};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

/// Sliding-window rate limiter keyed by an arbitrary string.
///
/// Each key may make at most `max_requests` accepted requests within any span
/// of `window`.
pub struct RateLimiter {
    /// Timestamps of accepted requests for each key.
    requests: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
    max_requests: usize,
    window: Duration,
    /// When idle keys were last pruned from `requests`. It is only locked while
    /// `requests` is held, so the lock order is always `requests` then `last_sweep`.
    last_sweep: Mutex<Instant>,
}

impl RateLimiter {
    /// Creates a limiter that allows `max_requests` requests per key per `window`.
    pub fn new(max_requests: usize, window: Duration) -> Self {
        Self {
            requests: Arc::new(Mutex::new(HashMap::new())),
            max_requests,
            window,
            last_sweep: Mutex::new(Instant::now()),
        }
    }

    /// Records a request for `key` and returns `true` if it is within the limit.
    ///
    /// Returns `false`, without recording the request, if `key` already has
    /// `max_requests` requests inside the current window.
    fn check_rate_limit(&self, key: &str) -> bool {
        // The map only holds timestamps, so a panic elsewhere while the lock was
        // held cannot leave it logically inconsistent. Recover the guard rather
        // than panicking on every later request.
        let mut requests = self
            .requests
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let now = Instant::now();

        self.sweep_idle_keys(&mut requests, now);

        let entry = requests.entry(key.to_string()).or_default();

        // Forget requests that have fallen out of the window.
        entry.retain(|&t| now.duration_since(t) < self.window);

        if entry.len() >= self.max_requests {
            false
        } else {
            entry.push(now);
            true
        }
    }

    /// Drops every key with no request inside the current window.
    ///
    /// Runs at most once per `window`, so the map cannot grow without bound as
    /// new keys (e.g. client IPs) appear and then go quiet.
    fn sweep_idle_keys(&self, requests: &mut HashMap<String, Vec<Instant>>, now: Instant) {
        let mut last_sweep = self
            .last_sweep
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if now.duration_since(*last_sweep) < self.window {
            return;
        }
        *last_sweep = now;

        requests.retain(|_, hits| {
            hits.retain(|&t| now.duration_since(t) < self.window);
            !hits.is_empty()
        });
    }
}

#[async_trait]
impl Middleware for RateLimiter {
    fn should_continue(&self, _request: &Request) -> bool {
        true
    }

    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        let client_ip = request.get_client_ip()
            .map(|ip| ip.to_string())
            .unwrap_or_else(|| "unknown".to_string());

        if self.check_rate_limit(&client_ip) {
            next.handle(request).await
        } else {
            Ok(Response::internal_server_error()
                .with_body("Rate limit exceeded")
                .with_stop_chain(true))
        }
    }
}

Middleware Ordering

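Middleware execution order matters. A typical recommended order is listed below. The code example after the list registers a subset of these middleware in the same relative order: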

  1. RequestIdMiddleware - Generate request ID first
  2. LoggingMiddleware - Log all requests
  3. TracingMiddleware - Start tracing span
  4. SecurityMiddleware - Apply security headers
  5. CorsMiddleware - Handle CORS preflight
  6. SessionMiddleware - Load session
  7. AuthenticationMiddleware - Authenticate user
  8. CsrfMiddleware - Validate CSRF token
  9. RateLimitMiddleware - Apply rate limits
  10. Application handlers
use reinhardt::ServerRouter;
use reinhardt::{
    LoggingMiddleware, SecurityMiddleware, CorsMiddleware,
    SessionMiddleware, CsrfMiddleware, RateLimitMiddleware
};

// `store` and `strategy` are not defined in this snippet and must be created
// beforehand. NOTE(review): the session store and the rate-limit store are
// probably different types — confirm against the SessionMiddleware and
// RateLimitMiddleware constructors.
// NOTE(review): presumably middleware registered first wraps the ones registered
// later (i.e. sees the request first) — confirm in the ServerRouter docs.
let router = ServerRouter::new()
    .with_middleware(LoggingMiddleware::new())
    .with_middleware(SecurityMiddleware::new())
    .with_middleware(CorsMiddleware::permissive())
    .with_middleware(SessionMiddleware::new(store))
    .with_middleware(CsrfMiddleware::default())
    .with_middleware(RateLimitMiddleware::new(strategy, store));

Available Middleware

Reinhardt includes 30+ built-in middleware components. The most commonly used ones are listed below by category.

Authentication & Authorization

| Middleware | Description |
|---|---|
| AuthenticationMiddleware | Session-based user authentication |

Security

| Middleware | Description |
|---|---|
| CorsMiddleware | Cross-Origin Resource Sharing |
| CsrfMiddleware | CSRF token protection |
| CspMiddleware | Content Security Policy headers |
| XFrameOptionsMiddleware | Clickjacking protection |
| HttpsRedirectMiddleware | Force HTTPS connections |
| SecurityMiddleware | Combined security headers |

Performance

| Middleware | Description |
|---|---|
| CacheMiddleware | HTTP response caching |
| GZipMiddleware | Gzip compression |
| BrotliMiddleware | Brotli compression |
| ETagMiddleware | ETag generation and validation |
| ConditionalGetMiddleware | Conditional GET support |

Observability

| Middleware | Description |
|---|---|
| LoggingMiddleware | Request/response logging |
| TracingMiddleware | Distributed tracing |
| MetricsMiddleware | Performance metrics collection |
| RequestIdMiddleware | Unique request ID generation |

Rate Limiting

| Middleware | Description |
|---|---|
| RateLimitMiddleware | API rate limiting |

Resilience

| Middleware | Description |
|---|---|
| CircuitBreakerMiddleware | Circuit breaker pattern |
| TimeoutMiddleware | Request timeout handling |

Session & State

| Middleware | Description |
|---|---|
| SessionMiddleware | Session management |
| SiteMiddleware | Multi-site support |
| LocaleMiddleware | Internationalization and locale detection |

Utility

| Middleware | Description |
|---|---|
| CommonMiddleware | Common HTTP functionality |
| BrokenLinkEmailsMiddleware | Broken link notification |
| FlatpagesMiddleware | Static page serving from database |

See Also