Files
honeydany/app/dashboard/security_headers.go
T

112 lines
2.9 KiB
Go

package dashboard
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)
// SecurityHeadersMiddleware wraps next so that every response carries a fixed
// set of defensive HTTP headers. The headers are set before next runs, so a
// downstream handler can still override any of them.
//
// Headers written:
//   - X-Frame-Options: DENY, alongside the CSP frame-ancestors 'none'
//     directive, to forbid framing (clickjacking); the former covers browsers
//     that do not understand frame-ancestors.
//   - X-Content-Type-Options: nosniff, to stop MIME-type sniffing.
//   - X-XSS-Protection: 0, which switches off the legacy XSS auditor. Modern
//     browsers have removed the auditor, and in older ones enabling it could
//     itself introduce vulnerabilities, so current guidance is to disable it
//     and rely on the Content-Security-Policy instead.
//   - Strict-Transport-Security (one year, including subdomains), only when
//     this server terminated TLS for the request (r.TLS != nil).
//   - Content-Security-Policy: same-origin by default; scripts and styles may
//     also come from https://cdn.tailwindcss.com and may be inline
//     ('unsafe-inline', which weakens the policy's XSS protection).
//   - Referrer-Policy: strict-origin-when-cross-origin.
//   - Permissions-Policy denying geolocation, microphone and camera.
//
// Any Server header already present on w is deleted rather than overwritten
// with an empty value, so no empty "Server:" line is emitted.
func SecurityHeadersMiddleware(next http.HandlerFunc) http.HandlerFunc {
	// The policy is identical for every request, so build it once here
	// instead of re-joining the directives on each call.
	csp := strings.Join([]string{
		"default-src 'self'",
		"script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com",
		"style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com",
		"img-src 'self' data:",
		"font-src 'self'",
		"connect-src 'self'",
		"frame-ancestors 'none'",
		"base-uri 'self'",
		"form-action 'self'",
	}, "; ")

	return func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-XSS-Protection", "0")
		if r.TLS != nil {
			// Browsers ignore HSTS received over plain HTTP, so only send it
			// on connections this server itself secured with TLS.
			h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}
		h.Set("Content-Security-Policy", csp)
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		h.Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
		h.Del("Server")
		next(w, r)
	}
}
// IPWhitelistMiddleware restricts dashboard access to the clients listed in
// allowedIPs. Any other client gets 403 Forbidden and next is not called.
// The client address comes from getClientIP.
//
// Each entry may be:
//   - a single IP address ("192.0.2.10", "2001:db8::1"). Matching compares
//     parsed addresses, so equivalent spellings such as "2001:DB8:0::1" or
//     "::ffff:192.0.2.10" match too;
//   - a CIDR range ("10.0.0.0/8");
//   - "0.0.0.0", which disables the check and allows every client.
//
// Entries that parse as neither an IP nor a CIDR are compared verbatim with
// the client address. Surrounding whitespace is ignored and blank entries are
// skipped. An empty list denies every request. The list is processed once,
// when the middleware is built, so later changes to the caller's slice have
// no effect.
func IPWhitelistMiddleware(allowedIPs []string) func(http.HandlerFunc) http.HandlerFunc {
	// canonical maps any textual IP to net.IP's canonical string so that
	// different spellings of the same address compare equal. Non-IP text is
	// returned unchanged.
	canonical := func(s string) string {
		if ip := net.ParseIP(s); ip != nil {
			return ip.String()
		}
		return s
	}

	var (
		allowAll bool
		exact    = make(map[string]struct{}, len(allowedIPs))
		networks []*net.IPNet
	)
	for _, entry := range allowedIPs {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		if entry == "0.0.0.0" { // 0.0.0.0 allows all
			allowAll = true
			continue
		}
		if _, network, err := net.ParseCIDR(entry); err == nil {
			networks = append(networks, network)
			continue
		}
		exact[canonical(entry)] = struct{}{}
	}

	isAllowed := func(clientIP string) bool {
		if allowAll {
			return true
		}
		if _, ok := exact[canonical(clientIP)]; ok {
			return true
		}
		if ip := net.ParseIP(clientIP); ip != nil {
			for _, network := range networks {
				if network.Contains(ip) {
					return true
				}
			}
		}
		return false
	}

	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			if !isAllowed(getClientIP(r)) {
				http.Error(w, "Access denied", http.StatusForbidden)
				return
			}
			next(w, r)
		}
	}
}
// getClientIP returns the IP address (no port, no IPv6 brackets) that r is
// attributed to, for IP whitelisting and similar decisions.
//
// The TCP peer address in r.RemoteAddr is the only value a client cannot
// forge. X-Forwarded-For and X-Real-IP are ordinary request headers, so they
// are honoured only when that peer is a loopback or private-network address,
// which is plausibly a reverse proxy in front of the dashboard. From any
// other peer they are ignored.
//
// When the peer is trusted, the rightmost non-empty X-Forwarded-For entry is
// used. Proxies conventionally append the address they received the request
// from, so the rightmost entry was written by the nearest proxy, while
// leftmost entries are whatever the client sent. X-Real-IP is used when
// X-Forwarded-For is absent or empty. Otherwise the peer address is returned.
//
// NOTE(review): if the reverse proxy reaches the app from a public address
// (e.g. a CDN connecting directly), this returns the proxy's address; extend
// the trusted-peer check for such deployments.
func getClientIP(r *http.Request) string {
	peer := r.RemoteAddr
	if host, _, err := net.SplitHostPort(peer); err == nil {
		peer = host // also strips IPv6 brackets: "[::1]:80" -> "::1"
	}

	if ip := net.ParseIP(peer); ip == nil || !(ip.IsLoopback() || ip.IsPrivate()) {
		return peer
	}

	// Several X-Forwarded-For header lines are equivalent to a single
	// comma-separated list, so join them before splitting.
	entries := strings.Split(strings.Join(r.Header.Values("X-Forwarded-For"), ","), ",")
	for i := len(entries) - 1; i >= 0; i-- {
		if ip := strings.TrimSpace(entries[i]); ip != "" {
			return ip
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}
	return peer
}
// RateLimitMiddleware implements basic rate limiting
func RateLimitMiddleware(requestsPerMinute int) func(http.HandlerFunc) http.HandlerFunc {
// Simple in-memory rate limiter (use Redis in production)
// This is a basic implementation - consider using a proper rate limiter
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// For now, just pass through - implement proper rate limiting
next(w, r)
}
}
}