Files
honeydany/app/dashboard/csrf.go
T

313 lines
7.2 KiB
Go

package dashboard
import (
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"sync"
"time"
)
// CSRFManager issues and checks single-use CSRF tokens. Each token is stored in
// an in-memory map together with its expiry instant. A background goroutine
// started by NewCSRFManager periodically purges expired entries until Stop is
// called.
//
// Every access to the token map is guarded by mutex.
type CSRFManager struct {
	tokens map[string]time.Time // token -> expiry instant
	mutex  sync.RWMutex         // guards tokens

	stop     chan struct{} // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once     // makes Stop idempotent
}

// NewCSRFManager returns a ready-to-use manager and starts its background
// cleanup goroutine. Call Stop when the manager is no longer needed so that
// goroutine (which also keeps the manager reachable) can exit.
func NewCSRFManager() *CSRFManager {
	cm := &CSRFManager{
		tokens: make(map[string]time.Time),
		stop:   make(chan struct{}),
	}
	go cm.cleanupExpiredTokens()
	return cm
}

// Stop terminates the background cleanup goroutine. It is safe to call more
// than once, and on a manager that was not created by NewCSRFManager (nil stop
// channel). Generating and checking tokens still works after Stop; expired
// tokens are then only removed when they are looked up.
func (cm *CSRFManager) Stop() {
	cm.stopOnce.Do(func() {
		if cm.stop != nil {
			close(cm.stop)
		}
	})
}

// GenerateToken creates a token from 32 bytes of crypto/rand output, encodes
// it with URL-safe base64, and registers it as valid for one hour. It returns
// an error only if reading from the random source fails.
func (cm *CSRFManager) GenerateToken() (string, error) {
	raw := make([]byte, 32)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	token := base64.URLEncoding.EncodeToString(raw)

	cm.mutex.Lock()
	cm.tokens[token] = time.Now().Add(1 * time.Hour)
	cm.mutex.Unlock()
	return token, nil
}

// ValidateToken reports whether token is known and not yet expired, WITHOUT
// consuming it. An expired token found during the check is removed.
func (cm *CSRFManager) ValidateToken(token string) bool {
	if token == "" {
		return false
	}

	cm.mutex.RLock()
	expiresAt, exists := cm.tokens[token]
	cm.mutex.RUnlock()
	if !exists {
		return false
	}

	if time.Now().After(expiresAt) {
		// Upgrade to a write lock only on the (rare) expired path.
		cm.mutex.Lock()
		delete(cm.tokens, token)
		cm.mutex.Unlock()
		return false
	}
	return true
}

// ConsumeToken reports whether token is known and not expired, and removes it
// in either case, so each token can succeed at most once. The lookup and the
// delete happen under one write lock, so two concurrent calls with the same
// token cannot both succeed.
func (cm *CSRFManager) ConsumeToken(token string) bool {
	if token == "" {
		return false
	}

	cm.mutex.Lock()
	defer cm.mutex.Unlock()

	expiresAt, exists := cm.tokens[token]
	if !exists {
		return false
	}
	delete(cm.tokens, token)
	return !time.Now().After(expiresAt)
}

// cleanupExpiredTokens runs on the goroutine started by NewCSRFManager. Every
// 10 minutes it deletes all tokens whose expiry has passed. It returns as soon
// as Stop closes cm.stop. With a nil stop channel that case never fires and the
// loop simply keeps ticking.
func (cm *CSRFManager) cleanupExpiredTokens() {
	ticker := time.NewTicker(10 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-cm.stop:
			return
		case <-ticker.C:
			now := time.Now()
			cm.mutex.Lock()
			for token, expiresAt := range cm.tokens {
				if now.After(expiresAt) {
					delete(cm.tokens, token) // deleting during range is safe in Go
				}
			}
			cm.mutex.Unlock()
		}
	}
}
// CSRFMiddleware wraps next so that every request whose method is not GET,
// HEAD or OPTIONS must carry a valid CSRF token. The token is read from the
// X-CSRF-Token header, falling back to the csrf_token form field, and is
// consumed (single use). Requests without a valid token get 403 Forbidden and
// next is not called.
func (cm *CSRFManager) CSRFMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions:
			// These methods are exempt from the token check.
			next(w, r)
			return
		}

		submitted := r.Header.Get("X-CSRF-Token")
		if submitted == "" {
			submitted = r.FormValue("csrf_token")
		}

		if cm.ConsumeToken(submitted) {
			next(w, r)
			return
		}
		http.Error(w, "Invalid CSRF token", http.StatusForbidden)
	}
}
// AddCSRFTokenToResponse generates a new CSRF token via GenerateToken, stores
// it in data under the "CSRFToken" key, and sets it as the X-CSRF-Token
// response header. Call it before the response status or body is written:
// net/http ignores header changes made after WriteHeader/Write.
//
// data must be a non-nil map. A nil map would panic on assignment, so it is
// rejected with an error before any token is generated; this avoids leaving
// an unused token registered. Token-generation failures are returned wrapped.
func (cm *CSRFManager) AddCSRFTokenToResponse(w http.ResponseWriter, data map[string]interface{}) error {
	if data == nil {
		return fmt.Errorf("cannot add CSRF token: data map is nil")
	}

	token, err := cm.GenerateToken()
	if err != nil {
		return fmt.Errorf("failed to generate CSRF token: %w", err)
	}

	data["CSRFToken"] = token
	w.Header().Set("X-CSRF-Token", token)
	return nil
}
// InputValidator groups input-validation and sanitizing helpers. It has no
// fields, so its zero value is ready to use.
type InputValidator struct{}

// NewInputValidator returns a pointer to a fresh InputValidator.
func NewInputValidator() *InputValidator {
	return new(InputValidator)
}
// ValidateUsername checks that username is 3 to 50 bytes long (both bounds
// inclusive) and contains only ASCII letters, digits, underscores and hyphens.
// Since only ASCII is accepted, the byte length equals the character count for
// any username that passes. It returns nil when the username is acceptable.
func (iv *InputValidator) ValidateUsername(username string) error {
	if len(username) < 3 {
		return fmt.Errorf("username must be at least 3 characters long")
	}
	// The upper bound is inclusive: exactly 50 characters is accepted.
	if len(username) > 50 {
		return fmt.Errorf("username must be at most 50 characters long")
	}

	for _, char := range username {
		if !((char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') ||
			(char >= '0' && char <= '9') || char == '_' || char == '-') {
			return fmt.Errorf("username can only contain letters, numbers, underscore, and hyphen")
		}
	}
	return nil
}
// ValidatePassword checks password strength. The password must be 8 to 128
// bytes long (both bounds inclusive) and contain only printable ASCII (32-126,
// space included). It must have at least one uppercase letter, one lowercase
// letter, one digit, and one other printable character (punctuation or space).
// Rules are checked in that order and the first failure is returned.
func (iv *InputValidator) ValidatePassword(password string) error {
	if len(password) < 8 {
		return fmt.Errorf("password must be at least 8 characters long")
	}
	// The upper bound is inclusive: exactly 128 characters is accepted.
	if len(password) > 128 {
		return fmt.Errorf("password must be at most 128 characters long")
	}

	var hasUpper, hasLower, hasDigit, hasSpecial bool
	for _, char := range password {
		switch {
		case char >= 'A' && char <= 'Z':
			hasUpper = true
		case char >= 'a' && char <= 'z':
			hasLower = true
		case char >= '0' && char <= '9':
			hasDigit = true
		case char >= 32 && char <= 126:
			// Letters and digits were matched above, so anything left in the
			// printable ASCII range is a special character.
			hasSpecial = true
		default:
			return fmt.Errorf("password contains invalid characters")
		}
	}

	if !hasUpper {
		return fmt.Errorf("password must contain at least one uppercase letter")
	}
	if !hasLower {
		return fmt.Errorf("password must contain at least one lowercase letter")
	}
	if !hasDigit {
		return fmt.Errorf("password must contain at least one digit")
	}
	if !hasSpecial {
		return fmt.Errorf("password must contain at least one special character")
	}
	return nil
}
// ValidateEmail performs a lightweight syntactic check of an optional email
// address. An empty string is accepted because the field is optional.
// Otherwise the address must:
//   - be at most 254 bytes,
//   - contain no ASCII whitespace or control characters,
//   - contain exactly one '@' with a non-empty part before it, and
//   - have a domain after the '@' with at least one '.' and no empty
//     dot-separated labels ("a@.com", "a@b." and "a@b..c" are rejected).
//
// This is not a full RFC 5322 parser. The local part is only required to be
// non-empty, and non-ASCII bytes are passed through unchecked.
func (iv *InputValidator) ValidateEmail(email string) error {
	if email == "" {
		return nil // Email is optional
	}
	if len(email) > 254 {
		return fmt.Errorf("email address is too long")
	}

	at := -1
	for i := 0; i < len(email); i++ {
		c := email[i]
		// Whitespace and control characters are never valid in an unquoted address.
		if c <= ' ' || c == 0x7f {
			return fmt.Errorf("invalid email format")
		}
		if c == '@' {
			if at >= 0 {
				return fmt.Errorf("invalid email format") // more than one '@'
			}
			at = i
		}
	}
	// No '@', or an empty local part / domain.
	if at <= 0 || at == len(email)-1 {
		return fmt.Errorf("invalid email format")
	}

	domain := email[at+1:]
	hasDot := false
	prev := byte('.') // pretend the domain follows a dot so a leading '.' is rejected
	for i := 0; i < len(domain); i++ {
		if domain[i] == '.' {
			if prev == '.' {
				return fmt.Errorf("invalid email format") // empty label
			}
			hasDot = true
		}
		prev = domain[i]
	}
	if !hasDot || prev == '.' {
		return fmt.Errorf("invalid email format")
	}
	return nil
}
// ValidateRole accepts exactly one of the role names "admin", "user" or
// "readonly" (case-sensitive) and returns an error for anything else.
func (iv *InputValidator) ValidateRole(role string) error {
	switch role {
	case "admin", "user", "readonly":
		return nil
	default:
		return fmt.Errorf("invalid role: must be admin, user, or readonly")
	}
}
// ValidateAPIKeyName checks that name is 1 to 100 bytes long (both bounds
// inclusive) and contains only ASCII letters, digits, underscores, hyphens and
// spaces. It returns nil when the name is acceptable.
func (iv *InputValidator) ValidateAPIKeyName(name string) error {
	if name == "" {
		return fmt.Errorf("API key name is required")
	}
	// The upper bound is inclusive: exactly 100 characters is accepted.
	if len(name) > 100 {
		return fmt.Errorf("API key name must be at most 100 characters")
	}

	for _, char := range name {
		if !((char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') ||
			(char >= '0' && char <= '9') || char == '_' || char == '-' || char == ' ') {
			return fmt.Errorf("API key name can only contain letters, numbers, underscore, hyphen, and spaces")
		}
	}
	return nil
}
// SanitizeString returns input with every character outside printable ASCII
// (32-126) removed. That includes NUL and other control characters, DEL, and
// all non-ASCII characters. Invalid UTF-8 bytes decode to U+FFFD and are
// dropped as well.
//
// The output is built into a buffer pre-sized to len(input). The previous
// string += in a loop was O(n^2).
func (iv *InputValidator) SanitizeString(input string) string {
	out := make([]byte, 0, len(input))
	for _, char := range input {
		if char >= 32 && char <= 126 { // Printable ASCII only
			out = append(out, byte(char))
		}
	}
	return string(out)
}
// ValidateInteger returns nil when min <= value <= max. Otherwise it returns
// an error naming the violated bound. The lower bound is checked first.
func (iv *InputValidator) ValidateInteger(value, min, max int) error {
	switch {
	case value < min:
		return fmt.Errorf("value must be at least %d", min)
	case value > max:
		return fmt.Errorf("value must be at most %d", max)
	default:
		return nil
	}
}