Files
winauthmon-server/utils/rate_limiter.py
T
2025-05-25 20:26:18 +01:00

228 lines
8.1 KiB
Python

import time
import logging
from functools import wraps
from flask import request, jsonify, g
import redis
from werkzeug.exceptions import TooManyRequests
import os
logger = logging.getLogger(__name__)
# Configure Redis connection (if available)
def configure_redis(config=None):
    """Configure the module-level Redis client used for rate limiting.

    The Redis URL is read from the ``REDIS_URL`` option of the
    ``rate_limiting`` config section when present, otherwise from the
    ``REDIS_URL`` environment variable. The connection is verified with a
    ``PING`` before it is published to ``redis_client``.

    Args:
        config: Optional config object exposing ``has_option`` and ``get``
            (e.g. a ``configparser.ConfigParser``).

    Side effects:
        Sets the global ``redis_client`` to a connected client, or to ``None``
        when no URL is configured or the connection attempt fails, in which
        case the in-memory fallback is used.
    """
    global redis_client

    redis_url = None
    if config and config.has_option('rate_limiting', 'REDIS_URL'):
        redis_url = config.get('rate_limiting', 'REDIS_URL')
    if not redis_url:
        redis_url = os.environ.get('REDIS_URL')

    if not redis_url:
        # Drop any client left over from an earlier call so that the log
        # message below is actually true.
        redis_client = None
        logger.info("No Redis URL configured, using in-memory rate limiting")
        return

    try:
        client = redis.from_url(redis_url)
        client.ping()  # Test connection
    except Exception as e:
        # Deliberately broad: a broken Redis must never prevent startup, the
        # limiter simply degrades to in-memory storage.
        logger.warning("Redis connection failed for rate limiting: %s", e)
        redis_client = None
    else:
        # Publish the client only once it has been verified.
        redis_client = client
        logger.info("Redis connected for rate limiting")
# Shared Redis client. Stays None until configure_redis() connects
# successfully; while it is None the helpers below use the in-memory store.
redis_client = None
# In-memory rate limit storage (fallback if Redis is not available).
# Maps counter key -> {'value': <count>, 'expires': <epoch seconds>}.
# Expired entries are only removed when get_rate_limit_value() looks them up.
rate_limit_storage = {}
def rate_limit(limit=60, per=60, scope_func=None):
    """
    Rate limiting decorator for routes.

    Each (request path, scope) pair gets its own counter that lives for
    ``per`` seconds. Once the counter exceeds ``limit`` the wrapped view is
    not called and a JSON 429 response is returned instead. Every response
    carries ``X-RateLimit-Limit``, ``X-RateLimit-Remaining`` and
    ``X-RateLimit-Reset`` headers; the same values are stored on
    ``g.rate_limit_headers``.

    Args:
        limit (int): Maximum number of requests allowed in the time period
        per (int): Time period in seconds
        scope_func (callable): Function to determine rate limit scope (default: by IP)

    Returns:
        callable: A decorator that wraps a Flask view function.

    Example usage:
        @app.route('/api/endpoint')
        @rate_limit(limit=10, per=60)  # 10 requests per minute
        def api_endpoint():
            return jsonify({"status": "success"})
    """
    def decorator(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            # Imported lazily, as before, so the module-level import list
            # does not need to change.
            from flask import make_response

            # Get scope key (default to client IP address)
            scope = scope_func() if scope_func else get_remote_address()

            # Create a unique key for this route and scope
            counter_key = f"rate_limit:{request.path}:{scope}"

            current = get_rate_limit_value(counter_key)
            if current is None:
                # New or expired window: start counting with a fresh TTL.
                set_rate_limit_value(counter_key, 1, per)
                current = 1
            else:
                current = increment_rate_limit_value(counter_key)
                if current == 1:
                    # The counter vanished between the read and the
                    # increment, so the increment started a new counter that
                    # has no expiry (or was not stored at all). Re-arm the
                    # window, otherwise the key would never expire.
                    set_rate_limit_value(counter_key, 1, per)

            # Expose the headers on g and attach them to the response below.
            g.rate_limit_headers = {
                'X-RateLimit-Limit': str(limit),
                'X-RateLimit-Remaining': str(max(0, limit - current)),
                'X-RateLimit-Reset': str(int(time.time() + get_ttl(counter_key)))
            }

            if current > limit:
                # Over the limit: answer 429 without calling the view.
                logger.warning(f"Rate limit exceeded for {scope} on {request.path}")
                response = jsonify({
                    'error': 'Too many requests',
                    'message': f'Rate limit of {limit} requests per {per} seconds exceeded'
                })
                response.status_code = 429
            else:
                response = f(*args, **kwargs)
                # Views may return str, dict, or (body, status[, headers])
                # tuples; normalise anything that is not already a response
                # object so the headers are never silently dropped.
                if not hasattr(response, 'headers'):
                    response = make_response(response)

            for header_name, header_value in g.rate_limit_headers.items():
                response.headers[header_name] = header_value
            return response
        return wrapped
    return decorator
def get_remote_address():
    """Get the client's IP address, respecting proxy headers.

    Sources are tried in order: the first entry of ``X-Forwarded-For``,
    then ``X-Real-IP``, then ``request.remote_addr``. A header whose value
    is empty after stripping is skipped, so that malformed headers do not
    collapse many clients into a single empty scope.

    Returns:
        The selected address string, or whatever ``request.remote_addr``
        holds when neither header yields a value.
    """
    # SECURITY NOTE(review): both headers can be set by the client. Unless a
    # trusted reverse proxy overwrites them, a caller can evade per-IP limits
    # by sending a different value on every request - confirm the deployment.

    # Check X-Forwarded-For header (used by Traefik and other proxies)
    forwarded_header = request.headers.get('X-Forwarded-For')
    if forwarded_header:
        # Get the first IP in the chain (original client IP)
        forwarded_for = forwarded_header.split(',')[0].strip()
        if forwarded_for:
            logger.debug(
                "Using X-Forwarded-For IP: %s (from: %s)",
                forwarded_for,
                forwarded_header,
            )
            return forwarded_for

    # Check X-Real-IP header (alternative proxy header)
    real_ip = (request.headers.get('X-Real-IP') or '').strip()
    if real_ip:
        logger.debug("Using X-Real-IP: %s", real_ip)
        return real_ip

    # Fallback to direct connection IP
    remote_addr = request.remote_addr
    logger.debug("Using direct remote_addr: %s", remote_addr)
    return remote_addr
# Redis implementations
def get_rate_limit_value(key):
    """Return the current counter for ``key``, or ``None`` if there is none.

    With a Redis client configured the value is read from Redis. Otherwise
    the in-memory store is consulted; an entry whose expiry time has passed
    is deleted and reported as missing.
    """
    if redis_client:
        raw = redis_client.get(key)
        if not raw:
            return None
        return int(raw)

    # In-memory fallback.
    if key not in rate_limit_storage:
        return None

    entry = rate_limit_storage[key]
    if time.time() > entry['expires']:
        # Expired: evict so the caller starts a fresh window.
        del rate_limit_storage[key]
        return None
    return entry['value']
def set_rate_limit_value(key, value, ttl):
    """Store ``value`` as the counter for ``key``, expiring in ``ttl`` seconds.

    Writes to Redis when a client is configured, otherwise to the in-memory
    store, where the absolute expiry time is recorded next to the value.
    """
    if not redis_client:
        # In-memory fallback: remember when this entry stops being valid.
        expires_at = time.time() + ttl
        rate_limit_storage[key] = {'value': value, 'expires': expires_at}
        return

    redis_client.setex(key, ttl, value)
def increment_rate_limit_value(key, ttl=None):
    """Increment the counter for ``key`` and return the new value.

    Args:
        key: Counter key.
        ttl: Optional lifetime in seconds. When given and the increment
            starts a new counter (the key was missing or expired), the new
            counter is stored with this expiry. Without it, a Redis counter
            created by the increment has no expiry, and the in-memory path
            stores nothing.

    Returns:
        int: The counter value after the increment; ``1`` when the key did
        not hold a live counter.
    """
    if redis_client:
        count = redis_client.incr(key)
        if ttl is not None and count == 1:
            # INCR created the key; give it a lifetime so it cannot linger
            # forever and block the client permanently.
            redis_client.expire(key, ttl)
        return count

    # Fallback to in-memory storage
    entry = rate_limit_storage.get(key)
    if entry is not None and time.time() <= entry['expires']:
        entry['value'] += 1
        return entry['value']

    # Missing or expired: an expired entry must not keep accumulating hits.
    if ttl is not None:
        rate_limit_storage[key] = {'value': 1, 'expires': time.time() + ttl}
    else:
        rate_limit_storage.pop(key, None)
    return 1
def get_ttl(key):
    """Return the seconds left before ``key`` expires, never below zero.

    Reads the TTL from Redis when a client is configured; otherwise it is
    computed from the expiry time kept in the in-memory store. A key that is
    not in the in-memory store yields 0.
    """
    if redis_client:
        remaining = redis_client.ttl(key)
        return remaining if remaining > 0 else 0

    # In-memory fallback.
    if key not in rate_limit_storage:
        return 0

    remaining = rate_limit_storage[key]['expires'] - time.time()
    return remaining if remaining > 0 else 0
# Utility to apply rate limits to entire blueprints
def apply_rate_limits(app, config=None):
    """Apply rate limits to sensitive routes.

    Wraps the ``auth.login`` and ``auth.register`` view functions, and every
    view whose URL rule starts with ``/api/``, with ``rate_limit``. Each
    endpoint is wrapped at most once per call, even when several URL rules
    map to it; stacked wrappers would each increment the same counter and
    divide the effective limit. The login/register limits take precedence
    over the generic API limit for those two endpoints.

    Args:
        app: Flask application whose ``view_functions`` are replaced in place.
        config: Config object providing ``getboolean``/``getint`` for the
            ``rate_limiting`` section. When falsy, nothing is applied.
    """
    if not config:
        return

    # Check if rate limiting is enabled
    if not config.getboolean('rate_limiting', 'ENABLE_RATE_LIMITING', fallback=True):
        return

    # Configure Redis connection
    configure_redis(config)

    # Get rate limiting configuration
    login_limit = config.getint('rate_limiting', 'LOGIN_LIMIT', fallback=10)
    login_period = config.getint('rate_limiting', 'LOGIN_PERIOD', fallback=60)
    register_limit = config.getint('rate_limiting', 'REGISTER_LIMIT', fallback=5)
    register_period = config.getint('rate_limiting', 'REGISTER_PERIOD', fallback=300)
    api_limit = config.getint('rate_limiting', 'API_LIMIT', fallback=60)
    api_period = config.getint('rate_limiting', 'API_PERIOD', fallback=60)

    limited_endpoints = set()

    def _limit_endpoint(endpoint, limit, per):
        """Wrap ``endpoint`` with rate_limit unless absent or already wrapped."""
        if endpoint in limited_endpoints or endpoint not in app.view_functions:
            return
        app.view_functions[endpoint] = rate_limit(
            limit=limit, per=per
        )(app.view_functions[endpoint])
        limited_endpoints.add(endpoint)

    # Specific endpoints first so their stricter limits win over the API one.
    _limit_endpoint('auth.login', login_limit, login_period)
    _limit_endpoint('auth.register', register_limit, register_period)

    # API endpoints: several rules may share one endpoint, hence the dedupe.
    for rule in app.url_map.iter_rules():
        if rule.rule.startswith('/api/'):
            _limit_endpoint(rule.endpoint, api_limit, api_period)