228 lines
8.1 KiB
Python
228 lines
8.1 KiB
Python
import time
|
|
import logging
|
|
from functools import wraps
|
|
from flask import request, jsonify, g
|
|
import redis
|
|
from werkzeug.exceptions import TooManyRequests
|
|
import os
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Configure Redis connection (if available)
|
|
def configure_redis(config=None):
|
|
"""Configure Redis connection from config"""
|
|
global redis_client
|
|
|
|
redis_url = None
|
|
if config and config.has_option('rate_limiting', 'REDIS_URL'):
|
|
redis_url = config.get('rate_limiting', 'REDIS_URL')
|
|
|
|
if not redis_url:
|
|
redis_url = os.environ.get('REDIS_URL', None)
|
|
|
|
if redis_url:
|
|
try:
|
|
redis_client = redis.from_url(redis_url)
|
|
redis_client.ping() # Test connection
|
|
logger.info("Redis connected for rate limiting")
|
|
except Exception as e:
|
|
logger.warning(f"Redis connection failed for rate limiting: {str(e)}")
|
|
redis_client = None
|
|
else:
|
|
logger.info("No Redis URL configured, using in-memory rate limiting")
|
|
|
|
# Initialize Redis to None
|
|
redis_client = None
|
|
|
|
# In-memory rate limit storage (fallback if Redis is not available)
|
|
rate_limit_storage = {}
|
|
|
|
def rate_limit(limit=60, per=60, scope_func=None):
|
|
"""
|
|
Rate limiting decorator for routes.
|
|
|
|
Args:
|
|
limit (int): Maximum number of requests allowed in the time period
|
|
per (int): Time period in seconds
|
|
scope_func (callable): Function to determine rate limit scope (default: by IP)
|
|
|
|
Example usage:
|
|
@app.route('/api/endpoint')
|
|
@rate_limit(limit=10, per=60) # 10 requests per minute
|
|
def api_endpoint():
|
|
return jsonify({"status": "success"})
|
|
"""
|
|
def decorator(f):
|
|
@wraps(f)
|
|
def wrapped(*args, **kwargs):
|
|
# Get scope key (default to IP address)
|
|
if scope_func:
|
|
scope = scope_func()
|
|
else:
|
|
# Default to client IP
|
|
scope = get_remote_address()
|
|
|
|
# Create a unique key for this route and scope
|
|
key = f"rate_limit:{request.path}:{scope}"
|
|
|
|
# Check rate limit
|
|
current = get_rate_limit_value(key)
|
|
|
|
# If this is a new key or expired, initialize it
|
|
if current is None:
|
|
set_rate_limit_value(key, 1, per)
|
|
current = 1
|
|
else:
|
|
# Increment counter
|
|
current = increment_rate_limit_value(key)
|
|
|
|
# Add rate limit headers
|
|
g.rate_limit_headers = {
|
|
'X-RateLimit-Limit': str(limit),
|
|
'X-RateLimit-Remaining': str(max(0, limit - current)),
|
|
'X-RateLimit-Reset': str(int(time.time() + get_ttl(key)))
|
|
}
|
|
|
|
# If over limit, return 429 Too Many Requests
|
|
if current > limit:
|
|
logger.warning(f"Rate limit exceeded for {scope} on {request.path}")
|
|
response = jsonify({
|
|
'error': 'Too many requests',
|
|
'message': f'Rate limit of {limit} requests per {per} seconds exceeded'
|
|
})
|
|
response.status_code = 429
|
|
|
|
# Add rate limit headers to response
|
|
for key, value in g.rate_limit_headers.items():
|
|
response.headers[key] = value
|
|
|
|
return response
|
|
|
|
# Execute the original route function
|
|
response = f(*args, **kwargs)
|
|
|
|
# Convert string response to Response object if needed
|
|
from flask import make_response
|
|
if isinstance(response, str):
|
|
response = make_response(response)
|
|
|
|
# Add rate limit headers to response
|
|
if hasattr(response, 'headers'):
|
|
for key, value in g.rate_limit_headers.items():
|
|
response.headers[key] = value
|
|
|
|
return response
|
|
return wrapped
|
|
return decorator
|
|
|
|
def get_remote_address():
|
|
"""Get the client's IP address, respecting proxy headers"""
|
|
# Check X-Forwarded-For header (used by Traefik and other proxies)
|
|
if request.headers.get('X-Forwarded-For'):
|
|
# Get the first IP in the chain (original client IP)
|
|
forwarded_for = request.headers.get('X-Forwarded-For').split(',')[0].strip()
|
|
logger.debug(f"Using X-Forwarded-For IP: {forwarded_for} (from: {request.headers.get('X-Forwarded-For')})")
|
|
return forwarded_for
|
|
|
|
# Check X-Real-IP header (alternative proxy header)
|
|
if request.headers.get('X-Real-IP'):
|
|
real_ip = request.headers.get('X-Real-IP').strip()
|
|
logger.debug(f"Using X-Real-IP: {real_ip}")
|
|
return real_ip
|
|
|
|
# Fallback to direct connection IP
|
|
remote_addr = request.remote_addr
|
|
logger.debug(f"Using direct remote_addr: {remote_addr}")
|
|
return remote_addr
|
|
|
|
# Redis implementations
|
|
def get_rate_limit_value(key):
|
|
"""Get current rate limit counter value"""
|
|
if redis_client:
|
|
value = redis_client.get(key)
|
|
return int(value) if value else None
|
|
else:
|
|
# Fallback to in-memory storage
|
|
if key in rate_limit_storage:
|
|
# Check if expired
|
|
if time.time() > rate_limit_storage[key]['expires']:
|
|
del rate_limit_storage[key]
|
|
return None
|
|
return rate_limit_storage[key]['value']
|
|
return None
|
|
|
|
def set_rate_limit_value(key, value, ttl):
|
|
"""Set rate limit counter with TTL"""
|
|
if redis_client:
|
|
redis_client.setex(key, ttl, value)
|
|
else:
|
|
# Fallback to in-memory storage
|
|
rate_limit_storage[key] = {
|
|
'value': value,
|
|
'expires': time.time() + ttl
|
|
}
|
|
|
|
def increment_rate_limit_value(key):
|
|
"""Increment rate limit counter"""
|
|
if redis_client:
|
|
return redis_client.incr(key)
|
|
else:
|
|
# Fallback to in-memory storage
|
|
if key in rate_limit_storage:
|
|
rate_limit_storage[key]['value'] += 1
|
|
return rate_limit_storage[key]['value']
|
|
return 1
|
|
|
|
def get_ttl(key):
|
|
"""Get remaining TTL for a key"""
|
|
if redis_client:
|
|
ttl = redis_client.ttl(key)
|
|
return max(0, ttl)
|
|
else:
|
|
# Fallback to in-memory storage
|
|
if key in rate_limit_storage:
|
|
return max(0, rate_limit_storage[key]['expires'] - time.time())
|
|
return 0
|
|
|
|
# Utility to apply rate limits to entire blueprints
|
|
def apply_rate_limits(app, config=None):
|
|
"""Apply rate limits to sensitive routes"""
|
|
if not config:
|
|
return
|
|
|
|
# Check if rate limiting is enabled
|
|
if not config.getboolean('rate_limiting', 'ENABLE_RATE_LIMITING', fallback=True):
|
|
return
|
|
|
|
# Configure Redis connection
|
|
configure_redis(config)
|
|
|
|
# Get rate limiting configuration
|
|
login_limit = config.getint('rate_limiting', 'LOGIN_LIMIT', fallback=10)
|
|
login_period = config.getint('rate_limiting', 'LOGIN_PERIOD', fallback=60)
|
|
register_limit = config.getint('rate_limiting', 'REGISTER_LIMIT', fallback=5)
|
|
register_period = config.getint('rate_limiting', 'REGISTER_PERIOD', fallback=300)
|
|
api_limit = config.getint('rate_limiting', 'API_LIMIT', fallback=60)
|
|
api_period = config.getint('rate_limiting', 'API_PERIOD', fallback=60)
|
|
|
|
# Login endpoint
|
|
if 'auth.login' in app.view_functions:
|
|
app.view_functions['auth.login'] = rate_limit(
|
|
limit=login_limit, per=login_period
|
|
)(app.view_functions['auth.login'])
|
|
|
|
# Register endpoint
|
|
if 'auth.register' in app.view_functions:
|
|
app.view_functions['auth.register'] = rate_limit(
|
|
limit=register_limit, per=register_period
|
|
)(app.view_functions['auth.register'])
|
|
|
|
# API endpoints
|
|
api_routes = [route for route in app.url_map.iter_rules()
|
|
if route.rule.startswith('/api/')]
|
|
|
|
for route in api_routes:
|
|
if route.endpoint in app.view_functions:
|
|
app.view_functions[route.endpoint] = rate_limit(
|
|
limit=api_limit, per=api_period
|
|
)(app.view_functions[route.endpoint]) |