155 lines
5.1 KiB
Python
155 lines
5.1 KiB
Python
"""Security utilities for CSRF protection and input validation."""
|
|
import secrets
|
|
import string
|
|
from typing import Optional
|
|
from fastapi import Request
|
|
import html
|
|
|
|
|
|
class CSRFTokenManager:
    """Generate and verify CSRF tokens for forms.

    Tokens are random URL-safe strings produced by :mod:`secrets`.
    Verification compares the submitted token with the expected one using
    ``secrets.compare_digest`` to avoid timing side channels.
    """

    @staticmethod
    def generate_token() -> str:
        """Return a new random CSRF token.

        Uses ``secrets.token_urlsafe(32)``: 32 random bytes, base64url-encoded
        without padding (43 ASCII characters).
        """
        return secrets.token_urlsafe(32)

    @staticmethod
    def store_token(request: "Request", token: str) -> None:
        """Placeholder for server-side token storage; currently does nothing.

        Nothing is persisted by this method. Under the cookie-based approach
        noted below, the caller presumably delivers the token to the client
        itself (e.g. as a cookie on the response) and later passes it back to
        :meth:`verify_token` — confirm against callers.

        Args:
            request: The incoming request (unused).
            token: The token to store (unused).
        """
        # Intentionally a no-op: for simplicity a cookie-based approach is
        # used. In production, store the token in a server-side session or
        # other secure storage instead.
        return None

    @staticmethod
    def verify_token(token1: Optional[str], token2: Optional[str]) -> bool:
        """Return True only if both tokens are non-empty strings and equal.

        Args:
            token1: One token (e.g. the expected token).
            token2: The other token (e.g. the one submitted with the form).

        Returns:
            ``True`` if the tokens match, ``False`` otherwise. Never raises.
        """
        # Missing/empty tokens must never verify: compare_digest("", "") is
        # True, so a request lacking both the expected and the submitted
        # token would otherwise pass the CSRF check.
        if not isinstance(token1, str) or not isinstance(token2, str):
            return False
        if not token1 or not token2:
            return False
        # compare_digest raises TypeError for str arguments containing
        # non-ASCII characters (which attacker-supplied input may contain),
        # so compare the UTF-8 encodings instead. "surrogatepass" keeps lone
        # surrogates from raising UnicodeEncodeError.
        return secrets.compare_digest(
            token1.encode("utf-8", "surrogatepass"),
            token2.encode("utf-8", "surrogatepass"),
        )
|
|
|
|
|
|
class InputValidator:
    """Validate and sanitize user inputs.

    Every method returns the cleaned value or raises ``ValueError`` with a
    message describing why the input was rejected. Non-string inputs are
    rejected with ``ValueError`` too (using the method's "missing/empty"
    message where it has one) rather than leaking ``AttributeError`` or
    ``TypeError`` to callers.
    """

    @staticmethod
    def sanitize_string(value: str, max_length: int = 500, allowed_chars: Optional[str] = None) -> str:
        """Strip, validate and HTML-escape a free-text string.

        Args:
            value: Raw user input.
            max_length: Maximum length of the stripped, *unescaped* input.
            allowed_chars: If given, every non-space character of the input
                must appear in this string. Spaces are always permitted.

        Returns:
            The stripped input with HTML special characters escaped. The
            escaped result may be longer than ``max_length``.

        Raises:
            ValueError: If the input is not a string, is empty after
                stripping, is too long, or contains a disallowed character.
        """
        if not isinstance(value, str):
            raise ValueError("Input must be a string")

        value = value.strip()

        if len(value) > max_length:
            raise ValueError(f"Input exceeds maximum length of {max_length}")

        if not value:
            raise ValueError("Input cannot be empty")

        # Check the whitelist against the raw input, *before* escaping.
        # Checking escaped text both rejects valid input ("'" becomes
        # "&#x27;") and lets forbidden characters through whenever the
        # whitelist covers their entity spelling ("<" becomes "&lt;", which
        # passes if "&", "l", "t" and ";" are allowed).
        if allowed_chars is not None:
            permitted = set(allowed_chars)
            if any(c not in permitted for c in value if c != " "):
                raise ValueError("Input contains invalid characters")

        # HTML-escape to prevent XSS when the value is rendered.
        return html.escape(value)

    @staticmethod
    def sanitize_email(email: str) -> str:
        """Normalize and validate an email address.

        The address is stripped and lower-cased. It must be at most 254
        characters, contain no whitespace or control characters, contain
        exactly one ``@`` with a non-empty local part, and have a domain of at
        least two non-empty dot-separated labels.

        Returns:
            The normalized address.

        Raises:
            ValueError: If the address is missing, too long or malformed.
        """
        if not isinstance(email, str):
            raise ValueError("Invalid email address")

        email = email.strip().lower()

        if not email or len(email) > 254:
            raise ValueError("Invalid email address")

        # strip() only trims the ends; an embedded newline or other control
        # character could inject headers if the address is later written into
        # an email, so reject whitespace/unprintables anywhere.
        if any(c.isspace() or not c.isprintable() for c in email):
            raise ValueError("Invalid email format")

        local, sep, domain = email.partition("@")
        if not sep or not local or "@" in domain:
            raise ValueError("Invalid email format")

        labels = domain.split(".")
        if len(labels) < 2 or not all(labels):
            raise ValueError("Invalid email format")

        return email

    @staticmethod
    def sanitize_slug(slug: str) -> str:
        """Normalize and validate a URL slug.

        The slug is stripped and lower-cased. It must be 1-50 characters drawn
        from ASCII lowercase letters, digits, ``-`` and ``_``, and must start
        and end with a letter or digit.

        Returns:
            The normalized slug.

        Raises:
            ValueError: If the slug violates any of the rules above.
        """
        if not isinstance(slug, str):
            raise ValueError("Slug must be between 1 and 50 characters")

        slug = slug.strip().lower()

        if not slug or len(slug) > 50:
            raise ValueError("Slug must be between 1 and 50 characters")

        # Only allow alphanumeric, hyphens, underscores
        valid_chars = set(string.ascii_lowercase + string.digits + "-_")
        if not all(c in valid_chars for c in slug):
            raise ValueError("Slug can only contain letters, numbers, hyphens, and underscores")

        # Must start and end with alphanumeric
        if not slug[0].isalnum() or not slug[-1].isalnum():
            raise ValueError("Slug must start and end with a letter or number")

        return slug

    @staticmethod
    def sanitize_password(password: str) -> str:
        """Check that a password is a string of 8-128 characters.

        The password is returned unchanged (no stripping or escaping).

        Raises:
            ValueError: If the password is missing, not a string, shorter
                than 8 or longer than 128 characters.
        """
        if not isinstance(password, str) or len(password) < 8:
            raise ValueError("Password must be at least 8 characters")

        if len(password) > 128:
            raise ValueError("Password is too long")

        return password

    @staticmethod
    def sanitize_url_list(urls: str) -> str:
        """Validate a newline-separated list of URLs.

        Blank lines are ignored and do not count toward the limit. Each
        remaining line is stripped, must be at most 2048 characters, must
        start with ``http://``, ``https://``, ``ftp://`` or ``ftps://``, and
        must include a host. At most 1000 URLs are accepted.

        Returns:
            The validated URLs joined with ``"\\n"``.

        Raises:
            ValueError: If the list is empty, contains an invalid or
                over-long URL, or has more than 1000 URLs.
        """
        from urllib.parse import urlsplit

        if not isinstance(urls, str) or not urls.strip():
            raise ValueError("URL list cannot be empty")

        allowed_prefixes = ("http://", "https://", "ftp://", "ftps://")
        validated_lines = []
        for raw_line in urls.split("\n"):
            line = raw_line.strip()
            if not line:
                continue  # Skip empty lines

            if len(line) > 2048:
                raise ValueError("URL is too long")

            # A scheme prefix alone ("http://") is not a URL: require a host.
            try:
                has_host = bool(urlsplit(line).netloc)
            except ValueError:  # e.g. an unbalanced "[" in an IPv6 host
                has_host = False
            if not line.startswith(allowed_prefixes) or not has_host:
                raise ValueError(f"Invalid URL: {line[:50]}...")

            validated_lines.append(line)
            # Count URLs, not raw lines, so blank separators don't consume
            # the quota; bail out as soon as the limit is exceeded.
            if len(validated_lines) > 1000:
                raise ValueError("URL list cannot contain more than 1000 URLs")

        # validated_lines is non-empty here: urls.strip() was non-empty, so at
        # least one line has a non-whitespace character and either raised
        # above or was appended.
        return "\n".join(validated_lines)

    @staticmethod
    def validate_list_name(name: str) -> str:
        """Validate a list name and HTML-escape it.

        The name is stripped and must be 1-100 characters (checked before
        escaping, so the escaped result may be longer).

        Raises:
            ValueError: If the name is missing, empty or too long.
        """
        if not isinstance(name, str):
            raise ValueError("List name cannot be empty")

        name = name.strip()

        if not name:
            raise ValueError("List name cannot be empty")

        if len(name) > 100:
            raise ValueError("List name cannot exceed 100 characters")

        # HTML escape
        return html.escape(name)
|