Files
mailgosend/internal/middleware/security.go
T
2026-05-24 17:15:48 +00:00

166 lines
4.3 KiB
Go

// Package middleware provides reusable HTTP middleware for security headers
// and in-memory rate limiting (token bucket per IP).
package middleware
import (
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"
)
// ---- Security Headers ----
// SecureHeaders returns a handler that stamps a fixed set of security headers
// onto every response and then delegates to next. The headers are set before
// next runs, so next can still override any of them.
//
// The Content-Security-Policy allows scripts only from this origin and from
// https://cdn.tailwindcss.com (a cross-origin CDN). script-src has no
// 'unsafe-inline', so inline scripts are blocked. style-src does allow inline
// styles, presumably for styles the Tailwind CDN script injects at runtime
// (NOTE(review): confirm).
func SecureHeaders(next http.Handler) http.Handler {
	const csp = "default-src 'self'; " +
		"script-src 'self' https://cdn.tailwindcss.com; " +
		"style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; " +
		"img-src 'self' data:; " +
		"font-src 'self'; " +
		"frame-src 'self'; " +
		"object-src 'none'; " +
		"base-uri 'self'; " +
		"form-action 'self'"

	headers := []struct{ name, value string }{
		{"Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()"},
		{"Content-Security-Policy", csp},
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		out := w.Header()
		for _, hv := range headers {
			out.Set(hv.name, hv.value)
		}
		next.ServeHTTP(w, r)
	})
}
// ---- Token-bucket rate limiter ----
// bucket holds one client's token-bucket state.
type bucket struct {
	tokens   float64   // tokens currently available; fractional between refills
	lastSeen time.Time // time of the most recent take/Allow for this key; drives refill and idle eviction
}

// RateLimiter is an in-memory token-bucket limiter keyed by client IP.
// It is safe for concurrent use: all bucket state is guarded by mu.
//
// NewRateLimiter starts a background goroutine that evicts idle buckets.
// Call Stop to end it when the limiter is no longer needed.
type RateLimiter struct {
	mu       sync.Mutex
	buckets  map[string]*bucket
	rate     float64 // tokens refilled per second
	capacity float64 // max token count (burst size)
	status   int     // HTTP status on exceeded (default 429)

	stopOnce sync.Once
	done     chan struct{} // closed by Stop to end the cleanup goroutine
}

// NewRateLimiter creates a limiter and starts its idle-bucket cleanup goroutine.
//   - ratePerMin: tokens refilled per minute (e.g. 20 = 20 req/min steady-state).
//     Negative values are treated as 0 (no refill).
//   - burst: bucket capacity, i.e. how many requests a new or fully refilled
//     client may make back to back (e.g. 5). Values below 1 are raised to 1,
//     because a capacity under one token would reject every request forever.
func NewRateLimiter(ratePerMin int, burst int) *RateLimiter {
	if ratePerMin < 0 {
		ratePerMin = 0
	}
	if burst < 1 {
		burst = 1
	}
	rl := &RateLimiter{
		buckets:  make(map[string]*bucket),
		rate:     float64(ratePerMin) / 60.0,
		capacity: float64(burst),
		status:   http.StatusTooManyRequests,
		done:     make(chan struct{}),
	}
	go rl.cleanup()
	return rl
}

// Allow reports whether a request keyed by ip is within the limit, consuming
// one token if it is.
func (rl *RateLimiter) Allow(ip string) bool {
	ok, _ := rl.take(ip)
	return ok
}

// take refills ip's bucket for the time elapsed since its last request and
// tries to consume one token. When no token is available, it also returns how
// long until one will have refilled (0 if the refill rate is zero).
func (rl *RateLimiter) take(ip string) (bool, time.Duration) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	b, ok := rl.buckets[ip]
	if !ok {
		// New clients start with a full bucket so they can burst immediately.
		b = &bucket{tokens: rl.capacity, lastSeen: now}
		rl.buckets[ip] = b
	}

	// Refill for elapsed time, capped at capacity.
	b.tokens += now.Sub(b.lastSeen).Seconds() * rl.rate
	if b.tokens > rl.capacity {
		b.tokens = rl.capacity
	}
	b.lastSeen = now

	if b.tokens < 1 {
		if rl.rate <= 0 {
			return false, 0
		}
		deficit := 1 - b.tokens
		return false, time.Duration(deficit / rl.rate * float64(time.Second))
	}
	b.tokens--
	return true, 0
}

// Middleware returns an http.Handler that rate-limits by client IP (see
// remoteIP). A rejected request gets rl.status (429) with a Retry-After header.
// The header holds the whole seconds, rounded up, until the client's next
// token refills, or 60 when the limiter never refills (rate 0).
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ok, wait := rl.take(remoteIP(r))
		if !ok {
			retryAfter := "60"
			if wait > 0 {
				// Round up so clients never retry before a token is available.
				secs := (wait + time.Second - 1) / time.Second
				retryAfter = strconv.FormatInt(int64(secs), 10)
			}
			w.Header().Set("Retry-After", retryAfter)
			http.Error(w, http.StatusText(rl.status), rl.status)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// Stop ends the cleanup goroutine started by NewRateLimiter. It is safe to
// call more than once. The limiter still works after Stop, but idle buckets
// are no longer evicted, so Stop is meant for shutdown and tests.
func (rl *RateLimiter) Stop() {
	if rl.done == nil {
		return
	}
	rl.stopOnce.Do(func() { close(rl.done) })
}

// cleanup removes buckets with no activity for more than 10 minutes, checking
// every 5 minutes, until Stop is called.
func (rl *RateLimiter) cleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-rl.done:
			return
		case <-ticker.C:
		}
		cutoff := time.Now().Add(-10 * time.Minute)
		rl.mu.Lock()
		for ip, b := range rl.buckets {
			if b.lastSeen.Before(cutoff) {
				delete(rl.buckets, ip)
			}
		}
		rl.mu.Unlock()
	}
}
// remoteIP returns the client IP used as the rate-limit key.
//
// X-Real-IP and X-Forwarded-For are ordinary request headers that any client
// can set. They are therefore honored only when the direct peer (r.RemoteAddr)
// is a loopback or private-network address, presumably a reverse proxy run by
// this deployment. Trusting them from arbitrary peers would let a client pick a
// fresh rate-limit bucket on every request just by sending a random header.
//
// When the peer is trusted:
//   - a parseable X-Real-IP wins;
//   - otherwise the X-Forwarded-For chain is walked from right to left and the
//     first address that is not itself loopback/private is returned. Proxies
//     append on the right, so the leftmost entries are whatever the client sent.
//   - if every entry walked is loopback/private, the leftmost one reached is
//     returned.
//
// In all other cases it returns the host part of r.RemoteAddr, or r.RemoteAddr
// verbatim if it cannot be split into host and port.
//
// NOTE(review): if the server sits behind a proxy or CDN with public addresses,
// isTrustedProxy must be extended, or every client will share the proxy's bucket.
func remoteIP(r *http.Request) string {
	peer := r.RemoteAddr
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		peer = host
	}
	if ip := net.ParseIP(peer); ip == nil || !isTrustedProxy(ip) {
		return peer
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		if ip := net.ParseIP(xri); ip != nil {
			return ip.String()
		}
	}

	// Repeated X-Forwarded-For header lines form one comma-separated list.
	if xff := strings.Join(r.Header.Values("X-Forwarded-For"), ","); xff != "" {
		hops := splitComma(xff)
		lastTrusted := ""
		for i := len(hops) - 1; i >= 0; i-- {
			hop := strings.TrimSpace(hops[i])
			if hop == "" {
				continue
			}
			ip := net.ParseIP(hop)
			if ip == nil {
				// Entries left of a malformed hop cannot be attributed to a
				// trusted proxy, so stop walking.
				break
			}
			if !isTrustedProxy(ip) {
				return ip.String()
			}
			lastTrusted = ip.String()
		}
		if lastTrusted != "" {
			return lastTrusted
		}
	}
	return peer
}

// isTrustedProxy reports whether ip is a loopback or private-network
// (RFC 1918 / RFC 4193) address. Such addresses are treated as this
// deployment's own reverse proxies.
func isTrustedProxy(ip net.IP) bool {
	return ip.IsLoopback() || ip.IsPrivate()
}

// splitComma splits s on every comma, keeping empty fields. It is exactly
// strings.Split(s, ","), so "" yields a single empty field.
func splitComma(s string) []string {
	return strings.Split(s, ",")
}