added test
This commit is contained in:
@@ -21,26 +21,29 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
// transportPool configures HTTP transport with timeouts matching Nginx defaults
|
||||
transportPool = &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext,
|
||||
ResponseHeaderTimeout: 60 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
MaxIdleConns: 100, // Max idle connections
|
||||
MaxIdleConnsPerHost: 10, // Max idle per host
|
||||
IdleConnTimeout: 90 * time.Second, // Idle connection timeout
|
||||
DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext, // Dial timeout
|
||||
ResponseHeaderTimeout: 60 * time.Second, // Response header timeout (Nginx-like)
|
||||
TLSHandshakeTimeout: 10 * time.Second, // TLS handshake timeout
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // Default skip TLS verification
|
||||
}
|
||||
cache = make(map[string]cachedResponse)
|
||||
defaultCacheTTL = 5 * time.Minute
|
||||
cacheMutex sync.RWMutex
|
||||
cache = make(map[string]cachedResponse) // Cache for static responses
|
||||
cacheMutex sync.RWMutex // Mutex for cache access
|
||||
|
||||
// Rate limiter per client IP
|
||||
// Rate limiting per client IP
|
||||
rateLimiters = make(map[string]*rate.Limiter)
|
||||
rateMutex sync.RWMutex
|
||||
rateLimit = rate.Limit(10) // 10 requests per second per client
|
||||
rateBurst = 20 // Allow burst of 20 requests
|
||||
rateLimit = rate.Limit(10) // 10 req/s per client
|
||||
rateBurst = 20 // Burst allowance
|
||||
|
||||
defaultCacheTTL = 5 * time.Minute // Default TTL for cached responses
|
||||
)
|
||||
|
||||
// cachedResponse stores cached response details
|
||||
type cachedResponse struct {
|
||||
body []byte
|
||||
headers http.Header
|
||||
@@ -50,7 +53,8 @@ type cachedResponse struct {
|
||||
etag string
|
||||
}
|
||||
|
||||
func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
// getReverseProxy creates a reverse proxy for a target URL
|
||||
func getReverseProxy(target string, skipVerify bool, originalReq *http.Request) *httputil.ReverseProxy {
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
log.Printf("Error parsing target URL %s: %v", target, err)
|
||||
@@ -58,61 +62,68 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
return nil
|
||||
}
|
||||
|
||||
// director modifies the request to forward it to the target
|
||||
director := func(req *http.Request) {
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
// Preserve client's Host header for session continuity
|
||||
if req.Header.Get("Host") != "" {
|
||||
req.Host = req.Header.Get("Host")
|
||||
// Preserve the original client’s Host header for session continuity
|
||||
if originalReq.Header.Get("Host") != "" {
|
||||
req.Host = originalReq.Header.Get("Host")
|
||||
} else {
|
||||
req.Host = targetURL.Host
|
||||
}
|
||||
|
||||
// Preserve all request headers, including cookies
|
||||
for k, v := range req.Header {
|
||||
// Copy all request headers to ensure cookies and session data are preserved
|
||||
for k, v := range originalReq.Header {
|
||||
req.Header[k] = v
|
||||
}
|
||||
|
||||
// Preserve query parameters and path
|
||||
req.URL.RawQuery = req.URL.RawQuery
|
||||
// Preserve query parameters and full path
|
||||
req.URL.RawQuery = originalReq.URL.RawQuery
|
||||
if targetURL.Path != "" {
|
||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, "/")
|
||||
req.URL.Path = strings.TrimPrefix(originalReq.URL.Path, "/")
|
||||
req.URL.Path = singleJoin(targetURL.Path, req.URL.Path)
|
||||
}
|
||||
|
||||
// WebSocket support: pass upgrade headers
|
||||
// Ensure WebSocket headers are passed correctly
|
||||
if strings.ToLower(req.Header.Get("Upgrade")) == "websocket" {
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
}
|
||||
|
||||
trafficLogger.Printf("Request: %s %s -> %s [Host: %s]", req.Method, req.URL.String(), target, req.Host)
|
||||
// Log request details for debugging
|
||||
trafficLogger.Printf("Request: %s %s -> %s [Host: %s] [Headers: %v]", req.Method, req.URL.String(), target, req.Host, req.Header)
|
||||
}
|
||||
|
||||
transport := transportPool
|
||||
if !skipVerify {
|
||||
// Use a new transport with TLS verification if skipVerify is false
|
||||
transport = &http.Transport{
|
||||
MaxIdleConns: transportPool.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: transportPool.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: transportPool.IdleConnTimeout,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
||||
MaxIdleConns: transportPool.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: transportPool.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: transportPool.IdleConnTimeout,
|
||||
DialContext: transportPool.DialContext,
|
||||
ResponseHeaderTimeout: transportPool.ResponseHeaderTimeout,
|
||||
TLSHandshakeTimeout: transportPool.TLSHandshakeTimeout,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
||||
}
|
||||
}
|
||||
|
||||
return &httputil.ReverseProxy{
|
||||
Director: director,
|
||||
Transport: transport,
|
||||
// ModifyResponse processes responses, avoiding corruption
|
||||
ModifyResponse: func(resp *http.Response) error {
|
||||
// Preserve all response headers, including Set-Cookie
|
||||
// Preserve all response headers, including Set-Cookie for sessions
|
||||
for k, v := range resp.Header {
|
||||
resp.Header[k] = v
|
||||
}
|
||||
|
||||
// Handle WebSocket upgrade
|
||||
// Skip modification for WebSocket responses
|
||||
if strings.ToLower(resp.Header.Get("Upgrade")) == "websocket" {
|
||||
return nil // No further modification needed for WebSocket
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compression if client supports it and response isn’t already compressed
|
||||
// Apply compression only if not already compressed and client supports it
|
||||
if resp.Header.Get("Content-Encoding") == "" && strings.Contains(resp.Request.Header.Get("Accept-Encoding"), "gzip") {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@@ -131,11 +142,11 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
|
||||
resp.Body = io.NopCloser(&buf)
|
||||
resp.Header.Set("Content-Encoding", "gzip")
|
||||
resp.Header.Del("Content-Length") // Length changes after compression
|
||||
resp.Header.Del("Content-Length")
|
||||
trafficLogger.Printf("Compressed response for %s", resp.Request.URL.String())
|
||||
}
|
||||
|
||||
// Cache static content with ETag support
|
||||
// Cache static content if applicable
|
||||
if shouldCache(resp) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@@ -147,7 +158,7 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
|
||||
etag := resp.Header.Get("ETag")
|
||||
if etag == "" {
|
||||
etag = generateETag(body) // Simple ETag generation if not provided
|
||||
etag = generateETag(body)
|
||||
}
|
||||
|
||||
cacheDuration := parseCacheControl(resp.Header.Get("Cache-Control"))
|
||||
@@ -165,13 +176,18 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
etag: etag,
|
||||
}
|
||||
cacheMutex.Unlock()
|
||||
resp.Header.Set("ETag", etag)
|
||||
trafficLogger.Printf("Cached response for %s [ETag: %s]", resp.Request.URL.String(), etag)
|
||||
}
|
||||
|
||||
trafficLogger.Printf("Response: %s %d from %s", resp.Status, resp.StatusCode, target)
|
||||
trafficLogger.Printf("Response: %s %d from %s [Headers: %v]", resp.Status, resp.StatusCode, target, resp.Header)
|
||||
return nil
|
||||
},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if err.Error() == "context canceled" {
|
||||
trafficLogger.Printf("Request canceled by client for %s: %v", r.Host, err)
|
||||
return
|
||||
}
|
||||
log.Printf("Proxy error for %s: %v", r.Host, err)
|
||||
errorLogger.Printf("Proxy error for %s: %v", r.Host, err)
|
||||
http.Error(w, "Proxy error", http.StatusBadGateway)
|
||||
@@ -179,6 +195,7 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
}
|
||||
}
|
||||
|
||||
// shouldCache checks if a response should be cached based on method and content type
|
||||
func shouldCache(resp *http.Response) bool {
|
||||
if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK {
|
||||
return false
|
||||
@@ -188,25 +205,27 @@ func shouldCache(resp *http.Response) bool {
|
||||
strings.HasPrefix(contentType, "application/javascript") || strings.HasPrefix(contentType, "application/json")
|
||||
}
|
||||
|
||||
// singleJoin combines path segments with a single slash
|
||||
func singleJoin(prefix, suffix string) string {
|
||||
prefix = strings.TrimSuffix(prefix, "/")
|
||||
suffix = strings.TrimPrefix(suffix, "/")
|
||||
return prefix + "/" + suffix
|
||||
}
|
||||
|
||||
// generateETag creates an ETag from the response body using MD5
|
||||
func generateETag(body []byte) string {
|
||||
return fmt.Sprintf(`"%x"`, md5.Sum(body)) // Simple ETag based on MD5 hash
|
||||
return fmt.Sprintf(`"%x"`, md5.Sum(body))
|
||||
}
|
||||
|
||||
// parseCacheControl extracts max-age from the Cache-Control header for caching duration
|
||||
func parseCacheControl(header string) time.Duration {
|
||||
if header == "" {
|
||||
return 0
|
||||
}
|
||||
parts := strings.Split(header, ",")
|
||||
for _, part := range parts {
|
||||
if strings.Contains(part, "max-age=") {
|
||||
ageStr := strings.TrimPrefix(part, "max-age=")
|
||||
ageStr = strings.TrimSpace(ageStr)
|
||||
if strings.HasPrefix(part, "max-age=") {
|
||||
ageStr := strings.TrimSpace(strings.TrimPrefix(part, "max-age="))
|
||||
if age, err := strconv.Atoi(ageStr); err == nil {
|
||||
return time.Duration(age) * time.Second
|
||||
}
|
||||
@@ -215,6 +234,7 @@ func parseCacheControl(header string) time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
// getLimiter manages rate limiters per client IP
|
||||
func getLimiter(ip string) *rate.Limiter {
|
||||
rateMutex.Lock()
|
||||
defer rateMutex.Unlock()
|
||||
@@ -227,12 +247,13 @@ func getLimiter(ip string) *rate.Limiter {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// handler processes incoming HTTP/HTTPS requests
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
configMux.RLock()
|
||||
defer configMux.RUnlock()
|
||||
|
||||
// Rate limiting based on client IP
|
||||
clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")] // Strip port
|
||||
clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")]
|
||||
limiter := getLimiter(clientIP)
|
||||
if !limiter.Allow() {
|
||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||
@@ -240,6 +261,7 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve target and settings for the requested host
|
||||
target, exists := config.Routes[r.Host]
|
||||
skipVerify := config.TrustTarget[r.Host]
|
||||
noHTTPSRedirect := config.NoHTTPSRedirect[r.Host]
|
||||
@@ -253,15 +275,16 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
||||
noHTTPSRedirect = config.NoHTTPSRedirect["*"]
|
||||
}
|
||||
|
||||
// Handle HTTP->HTTPS redirect if applicable
|
||||
isHTTPS := target[:5] == "https"
|
||||
isHTTPReq := r.TLS == nil
|
||||
|
||||
if isHTTPReq && isHTTPS && !noHTTPSRedirect {
|
||||
redirectURL := "https://" + r.Host + r.RequestURI
|
||||
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
|
||||
// Check cache for static content
|
||||
cacheKey := r.URL.String()
|
||||
cacheMutex.RLock()
|
||||
if cached, ok := cache[cacheKey]; ok && time.Since(cached.cachedAt) < cached.cacheDuration {
|
||||
@@ -281,50 +304,65 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
cacheMutex.RUnlock()
|
||||
|
||||
proxy := getReverseProxy(target, skipVerify)
|
||||
proxy := getReverseProxy(target, skipVerify, r)
|
||||
if proxy == nil {
|
||||
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Error parsing target URL %s: %v", target, err)
|
||||
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// WebSocket upgrade handling
|
||||
// Handle WebSocket connections
|
||||
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Error parsing target URL for WebSocket: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "WebSocket upgrade not supported", http.StatusInternalServerError)
|
||||
errorLogger.Printf("WebSocket upgrade not supported by server")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
conn, _, err := hijacker.Hijack()
|
||||
clientConn, _, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
errorLogger.Printf("Failed to hijack connection for WebSocket: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
targetConn, err := transportPool.Dial("tcp", targetURL.Host)
|
||||
dialer := transportPool
|
||||
if !skipVerify {
|
||||
dialer = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
||||
}
|
||||
}
|
||||
targetConn, err := dialer.Dial("tcp", targetURL.Host)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Failed to dial target for WebSocket: %v", err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
// Forward request to target
|
||||
if err := r.Write(targetConn); err != nil {
|
||||
errorLogger.Printf("Failed to forward WebSocket request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Pipe connections
|
||||
go io.Copy(conn, targetConn)
|
||||
io.Copy(targetConn, conn)
|
||||
errChan := make(chan error, 2)
|
||||
go func() {
|
||||
_, err := io.Copy(targetConn, clientConn)
|
||||
errChan <- err
|
||||
}()
|
||||
go func() {
|
||||
_, err := io.Copy(clientConn, targetConn)
|
||||
errChan <- err
|
||||
}()
|
||||
trafficLogger.Printf("WebSocket connection established: %s -> %s", r.Host, target)
|
||||
<-errChan
|
||||
trafficLogger.Printf("WebSocket connection closed: %s -> %s", r.Host, target)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user