package dashboard import ( "net/http" "strings" ) // SecurityHeadersMiddleware adds security headers to all responses func SecurityHeadersMiddleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Prevent clickjacking w.Header().Set("X-Frame-Options", "DENY") // Prevent MIME type sniffing w.Header().Set("X-Content-Type-Options", "nosniff") // Enable XSS protection w.Header().Set("X-XSS-Protection", "1; mode=block") // Strict transport security (HTTPS only) if r.TLS != nil { w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") } // Content Security Policy csp := strings.Join([]string{ "default-src 'self'", "script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com", "style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com", "img-src 'self' data:", "font-src 'self'", "connect-src 'self'", "frame-ancestors 'none'", "base-uri 'self'", "form-action 'self'", }, "; ") w.Header().Set("Content-Security-Policy", csp) // Referrer policy w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") // Permissions policy w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()") // Remove server identification w.Header().Set("Server", "") next(w, r) } } // IPWhitelistMiddleware restricts dashboard access to specific IPs func IPWhitelistMiddleware(allowedIPs []string) func(http.HandlerFunc) http.HandlerFunc { return func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { clientIP := getClientIP(r) // Check if IP is in whitelist allowed := false for _, ip := range allowedIPs { if ip == clientIP || ip == "0.0.0.0" { // 0.0.0.0 allows all allowed = true break } } if !allowed { http.Error(w, "Access denied", http.StatusForbidden) return } next(w, r) } } } // getClientIP extracts the real client IP from request func getClientIP(r *http.Request) string { // Check X-Forwarded-For header xff := r.Header.Get("X-Forwarded-For") if xff != "" { ips := strings.Split(xff, ",") return strings.TrimSpace(ips[0]) } // Check X-Real-IP header xri := r.Header.Get("X-Real-IP") if xri != "" { return strings.TrimSpace(xri) } // Fall back to RemoteAddr ip := r.RemoteAddr if colon := strings.LastIndex(ip, ":"); colon != -1 { ip = ip[:colon] } return ip } // RateLimitMiddleware implements basic rate limiting func RateLimitMiddleware(requestsPerMinute int) func(http.HandlerFunc) http.HandlerFunc { // Simple in-memory rate limiter (use Redis in production) // This is a basic implementation - consider using a proper rate limiter return func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // For now, just pass through - implement proper rate limiting next(w, r) } } }