166 lines
4.3 KiB
Go
166 lines
4.3 KiB
Go
// Package middleware provides reusable HTTP middleware: a wrapper that sets
// security response headers, and an in-memory, per-IP token-bucket rate limiter.
|
|
package middleware
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// ---- Security Headers ----
|
|
|
|
// SecureHeaders wraps next so that every response carries a fixed set of
// security-related headers, then delegates to next.
//
// Content-Security-Policy summary:
//   - scripts load only from this origin and https://cdn.tailwindcss.com;
//     inline scripts are blocked because script-src has no 'unsafe-inline';
//   - styles additionally allow inline CSS ('unsafe-inline'), presumably for
//     the Tailwind CDN build — confirm before tightening;
//   - plugins are disabled (object-src 'none'), and <base> and form targets
//     are restricted to this origin;
//   - frame-ancestors 'none' mirrors X-Frame-Options: DENY for browsers that
//     apply CSP framing rules in preference to the legacy header.
//
// NOTE(review): frame-src 'self' permits same-origin iframes, yet every page
// served through this middleware refuses to be framed (DENY / 'none'). If
// same-origin framing is intended, use SAMEORIGIN / 'self' for both instead.
//
// Headers are applied with Set before next runs, so next may override them.
func SecureHeaders(next http.Handler) http.Handler {
	const csp = "default-src 'self'; " +
		"script-src 'self' https://cdn.tailwindcss.com; " +
		"style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; " +
		"img-src 'self' data:; " +
		"font-src 'self'; " +
		"frame-src 'self'; " +
		"object-src 'none'; " +
		"base-uri 'self'; " +
		"form-action 'self'; " +
		"frame-ancestors 'none'"

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-Frame-Options", "DENY")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		h.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()")
		h.Set("Content-Security-Policy", csp)
		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// ---- Token-bucket rate limiter ----
|
|
|
|
// bucket holds the token-bucket state for a single client key.
type bucket struct {
	// tokens is the (possibly fractional) balance as of lastSeen.
	tokens float64
	// lastSeen is the moment tokens was last recomputed.
	lastSeen time.Time
}
|
|
|
|
// RateLimiter is an in-memory token-bucket limiter keyed by IP.
|
|
// Safe for concurrent use.
|
|
type RateLimiter struct {
|
|
mu sync.Mutex
|
|
buckets map[string]*bucket
|
|
rate float64 // tokens refilled per second
|
|
capacity float64 // max token count
|
|
status int // HTTP status on exceeded (default 429)
|
|
}
|
|
|
|
// NewRateLimiter creates a limiter.
|
|
// - ratePerMin: tokens refilled per minute (e.g. 20 = 20 req/min steady-state)
|
|
// - burst: max burst size (e.g. 5 = 5 simultaneous requests)
|
|
func NewRateLimiter(ratePerMin int, burst int) *RateLimiter {
|
|
rl := &RateLimiter{
|
|
buckets: make(map[string]*bucket),
|
|
rate: float64(ratePerMin) / 60.0,
|
|
capacity: float64(burst),
|
|
status: http.StatusTooManyRequests,
|
|
}
|
|
// Periodic cleanup goroutine (runs for process lifetime).
|
|
go rl.cleanup()
|
|
return rl
|
|
}
|
|
|
|
// Allow returns true if the request is within limit, consuming one token.
|
|
func (rl *RateLimiter) Allow(ip string) bool {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
b, ok := rl.buckets[ip]
|
|
if !ok {
|
|
b = &bucket{tokens: rl.capacity, lastSeen: now}
|
|
rl.buckets[ip] = b
|
|
}
|
|
|
|
// Refill tokens based on elapsed time.
|
|
elapsed := now.Sub(b.lastSeen).Seconds()
|
|
b.tokens += elapsed * rl.rate
|
|
if b.tokens > rl.capacity {
|
|
b.tokens = rl.capacity
|
|
}
|
|
b.lastSeen = now
|
|
|
|
if b.tokens < 1 {
|
|
return false
|
|
}
|
|
b.tokens--
|
|
return true
|
|
}
|
|
|
|
// Middleware returns an http.Handler that rate-limits by remote IP.
|
|
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ip := remoteIP(r)
|
|
if !rl.Allow(ip) {
|
|
w.Header().Set("Retry-After", "60")
|
|
http.Error(w, http.StatusText(rl.status), rl.status)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// cleanup removes stale buckets every 5 minutes (no activity for >10 min).
|
|
func (rl *RateLimiter) cleanup() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
cutoff := time.Now().Add(-10 * time.Minute)
|
|
rl.mu.Lock()
|
|
for ip, b := range rl.buckets {
|
|
if b.lastSeen.Before(cutoff) {
|
|
delete(rl.buckets, ip)
|
|
}
|
|
}
|
|
rl.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// remoteIP returns the client address used as the rate-limit key.
//
// The default answer is the direct peer: the host part of r.RemoteAddr, or
// r.RemoteAddr itself when it has no port. X-Real-IP and X-Forwarded-For are
// set by the client and can say anything. They are consulted only when the
// peer is a trusted proxy (see isTrustedPeer). Trusting them from arbitrary
// peers would let any client bypass the limiter, and grow the bucket map
// without bound, by sending a different fake IP on every request.
//
// When the headers are trusted, a parseable X-Real-IP wins. Otherwise the
// first parseable X-Forwarded-For entry is used. Header-derived addresses
// are returned in net.IP's canonical string form.
//
// NOTE(review): a reverse proxy that reaches this server from a public
// address (e.g. some CDNs) is not trusted by this rule. Such a deployment
// needs an explicit trusted-proxy list.
//
// NOTE(review): the leftmost X-Forwarded-For entry is whatever the original
// client sent. It is still spoofable behind a proxy that appends to the
// header rather than overwriting it.
func remoteIP(r *http.Request) string {
	peer := r.RemoteAddr
	if host, _, err := net.SplitHostPort(peer); err == nil {
		peer = host
	}
	if !isTrustedPeer(peer) {
		return peer
	}

	if ip := net.ParseIP(strings.TrimSpace(r.Header.Get("X-Real-IP"))); ip != nil {
		return ip.String()
	}
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		for _, part := range splitComma(xff) {
			if ip := net.ParseIP(strings.TrimSpace(part)); ip != nil {
				return ip.String()
			}
		}
	}
	return peer
}

// isTrustedPeer reports whether a direct peer may be relied on to have set
// the forwarding headers. Trusted peers are:
//   - loopback addresses;
//   - private-network addresses (10/8, 172.16/12, 192.168/16, fc00::/7);
//   - non-IP peers such as Unix-socket connections, which are necessarily local.
//
// An IPv6 zone suffix ("%eth0") is ignored when classifying the address.
func isTrustedPeer(peer string) bool {
	if i := strings.IndexByte(peer, '%'); i >= 0 {
		peer = peer[:i]
	}
	ip := net.ParseIP(peer)
	if ip == nil {
		return true
	}
	return ip.IsLoopback() || ip.IsPrivate()
}

// splitComma splits s on every comma and keeps empty fields. It always returns
// at least one element (splitComma("") yields [""]).
func splitComma(s string) []string {
	return strings.Split(s, ",")
}
|
|
|