Files
GoLangProxy/proxy.go
2025-03-01 18:07:52 +00:00

333 lines
9.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bytes"
"compress/gzip"
"crypto/md5"
"crypto/tls"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/time/rate"
)
// Package-level state shared by every request served by the proxy.
var (
// transportPool is the upstream transport used by getReverseProxy when the
// route is configured to skip certificate verification (InsecureSkipVerify is
// true here). Only DialContext is configured; the legacy Dial field is left nil.
transportPool = &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext,
ResponseHeaderTimeout: 60 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
// cache holds buffered upstream responses keyed by a URL string; entries are
// written in getReverseProxy's ModifyResponse hook and read in handler.
// NOTE(review): nothing in this file removes entries, so the map only grows.
cache = make(map[string]cachedResponse)
// defaultCacheTTL is applied when parseCacheControl yields 0 for a response.
defaultCacheTTL = 5 * time.Minute
// cacheMutex guards cache.
cacheMutex sync.RWMutex
// Rate limiter per client IP
rateLimiters = make(map[string]*rate.Limiter)
// rateMutex guards rateLimiters.
rateMutex sync.RWMutex
rateLimit = rate.Limit(10) // 10 requests per second per client
rateBurst = 20 // Allow burst of 20 requests
)
// cachedResponse is one entry of the in-memory response cache: a fully
// buffered upstream response plus the metadata needed to decide freshness and
// to answer conditional requests.
type cachedResponse struct {
// body is the buffered response payload exactly as it was stored.
body []byte
// headers is a copy of the response headers taken when the entry was stored.
headers http.Header
// statusCode is the upstream status code replayed on a cache hit.
statusCode int
// cachedAt is the time the entry was stored.
cachedAt time.Time
// cacheDuration is how long after cachedAt the entry is considered fresh.
cacheDuration time.Duration
// etag is the upstream ETag, or one generated from body when upstream sent none.
etag string
}
// getReverseProxy builds a reverse proxy that forwards requests to target.
//
// target is the upstream base URL; skipVerify selects the shared transport
// that does not validate upstream TLS certificates (transportPool). When
// skipVerify is false a single, lazily created verifying transport is used so
// that upstream connections are pooled across requests instead of a fresh
// http.Transport being created (and its idle connections abandoned) per call.
//
// The returned proxy buffers a response only when it has to: to store it in
// the package cache (see shouldCache) and/or to gzip it for a client that
// advertised gzip support. Everything else is streamed untouched.
//
// It returns nil when target cannot be parsed.
func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
	targetURL, err := url.Parse(target)
	if err != nil {
		log.Printf("Error parsing target URL %s: %v", target, err)
		errorLogger.Printf("Error parsing target URL %s: %v", target, err)
		return nil
	}

	director := func(req *http.Request) {
		req.URL.Scheme = targetURL.Scheme
		req.URL.Host = targetURL.Host

		// net/http moves an inbound request's Host header into req.Host and
		// removes it from the header map, so this lookup only succeeds when a
		// "Host" entry was put into the map explicitly. Otherwise the upstream
		// sees the target's own host name.
		if host := req.Header.Get("Host"); host != "" {
			req.Host = host
		} else {
			req.Host = targetURL.Host
		}

		// Headers, cookies and the query string are already carried by req;
		// only the path needs rewriting when the target has a base path.
		if targetURL.Path != "" {
			req.URL.Path = strings.TrimPrefix(req.URL.Path, "/")
			req.URL.Path = singleJoin(targetURL.Path, req.URL.Path)
		}

		// WebSocket support: pass upgrade headers
		if strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
			req.Header.Set("Connection", "Upgrade")
		}
		trafficLogger.Printf("Request: %s %s -> %s [Host: %s]", req.Method, req.URL.String(), target, req.Host)
	}

	transport := transportPool
	if !skipVerify {
		transport = sharedVerifiedTransport()
	}

	modifyResponse := func(resp *http.Response) error {
		// A protocol switch has no body to cache or compress.
		if strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") {
			return nil
		}

		compress := gzipEligible(resp)
		store := shouldCache(resp)
		if compress || store {
			// Read the body exactly once and reuse the bytes for both steps.
			body, err := io.ReadAll(resp.Body)
			resp.Body.Close()
			if err != nil {
				if compress {
					errorLogger.Printf("Error reading response body for compression: %v", err)
				} else {
					errorLogger.Printf("Error reading response body for caching: %v", err)
				}
				return err
			}
			resp.Body = io.NopCloser(bytes.NewReader(body))
			upstreamURL := resp.Request.URL.String()

			// Cache before compressing so the stored entry is the upstream
			// representation and does not depend on what this particular
			// client put in Accept-Encoding.
			if store {
				etag := resp.Header.Get("ETag")
				if etag == "" {
					etag = generateETag(body) // Simple ETag generation if not provided
				}
				cacheDuration := parseCacheControl(resp.Header.Get("Cache-Control"))
				if cacheDuration == 0 {
					cacheDuration = defaultCacheTTL
				}
				entry := cachedResponse{
					body:          body,
					headers:       resp.Header.Clone(),
					statusCode:    resp.StatusCode,
					cachedAt:      time.Now(),
					cacheDuration: cacheDuration,
					etag:          etag,
				}
				cacheMutex.Lock()
				cache[upstreamURL] = entry
				cacheMutex.Unlock()
				trafficLogger.Printf("Cached response for %s [ETag: %s]", upstreamURL, etag)
			}

			if compress {
				var buf bytes.Buffer
				gw := gzip.NewWriter(&buf)
				if _, err := gw.Write(body); err != nil {
					errorLogger.Printf("Error compressing response: %v", err)
					return err
				}
				// Close flushes the gzip trailer; a failure here means the
				// stream in buf is truncated and must not be sent.
				if err := gw.Close(); err != nil {
					errorLogger.Printf("Error compressing response: %v", err)
					return err
				}
				resp.ContentLength = int64(buf.Len())
				resp.Header.Set("Content-Length", strconv.Itoa(buf.Len()))
				resp.Header.Set("Content-Encoding", "gzip")
				resp.Header.Add("Vary", "Accept-Encoding")
				resp.Body = io.NopCloser(&buf)
				trafficLogger.Printf("Compressed response for %s", upstreamURL)
			}
		}

		trafficLogger.Printf("Response: %s %d from %s", resp.Status, resp.StatusCode, target)
		return nil
	}

	return &httputil.ReverseProxy{
		Director:       director,
		Transport:      transport,
		ModifyResponse: modifyResponse,
		ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
			log.Printf("Proxy error for %s: %v", r.Host, err)
			errorLogger.Printf("Proxy error for %s: %v", r.Host, err)
			http.Error(w, "Proxy error", http.StatusBadGateway)
		},
	}
}

var (
	// verifiedTransportOnce guards the one-time construction of verifiedTransport.
	verifiedTransportOnce sync.Once
	// verifiedTransport is the shared upstream transport that validates TLS certificates.
	verifiedTransport *http.Transport
)

// sharedVerifiedTransport returns the process-wide transport used for routes
// that require upstream certificate verification. It is a clone of
// transportPool (same pool sizes and timeouts) with verification switched on.
func sharedVerifiedTransport() *http.Transport {
	verifiedTransportOnce.Do(func() {
		t := transportPool.Clone()
		t.TLSClientConfig = &tls.Config{InsecureSkipVerify: false}
		verifiedTransport = t
	})
	return verifiedTransport
}

// gzipEligible reports whether resp may be gzip-compressed on the fly: it must
// be able to carry a complete body (not a HEAD reply, 1xx, 204, 206 or 304),
// must not already be encoded, and the client must have advertised gzip.
func gzipEligible(resp *http.Response) bool {
	if resp.Request.Method == http.MethodHead {
		return false
	}
	switch {
	case resp.StatusCode < http.StatusOK,
		resp.StatusCode == http.StatusNoContent,
		resp.StatusCode == http.StatusPartialContent,
		resp.StatusCode == http.StatusNotModified:
		return false
	}
	if resp.Header.Get("Content-Encoding") != "" {
		return false
	}
	return strings.Contains(resp.Request.Header.Get("Accept-Encoding"), "gzip")
}
// shouldCache reports whether resp may be stored in the shared response cache.
//
// Only successful (200) replies to GET requests with a text, image,
// JavaScript or JSON content type qualify. A response is never cached when
// storing it in a cache shared between clients would be unsafe or explicitly
// forbidden:
//   - the request carried an Authorization header,
//   - the response sets cookies,
//   - Cache-Control contains no-store, no-cache, private or max-age=0.
func shouldCache(resp *http.Response) bool {
	if resp.Request == nil || resp.Request.Method != http.MethodGet || resp.StatusCode != http.StatusOK {
		return false
	}
	if resp.Request.Header.Get("Authorization") != "" {
		return false
	}
	if len(resp.Header.Values("Set-Cookie")) > 0 {
		return false
	}

	for _, line := range resp.Header.Values("Cache-Control") {
		for _, directive := range strings.Split(line, ",") {
			directive = strings.ToLower(strings.TrimSpace(directive))
			if directive == "max-age=0" {
				return false
			}
			// Directives such as private="field" carry an argument; only the
			// name decides whether the response is cacheable here.
			name := directive
			if i := strings.IndexByte(directive, '='); i >= 0 {
				name = directive[:i]
			}
			switch name {
			case "no-store", "no-cache", "private":
				return false
			}
		}
	}

	// Media types are case-insensitive.
	contentType := strings.ToLower(resp.Header.Get("Content-Type"))
	for _, prefix := range []string{"text/", "image/", "application/javascript", "application/json"} {
		if strings.HasPrefix(contentType, prefix) {
			return true
		}
	}
	return false
}
// singleJoin concatenates prefix and suffix with exactly one "/" inserted
// between them, after dropping at most one trailing slash from prefix and at
// most one leading slash from suffix.
func singleJoin(prefix, suffix string) string {
	head := strings.TrimSuffix(prefix, "/")
	tail := strings.TrimPrefix(suffix, "/")
	return strings.Join([]string{head, tail}, "/")
}
// generateETag derives an entity tag for body: the lowercase hexadecimal MD5
// digest of the bytes wrapped in double quotes. MD5 serves only as a content
// fingerprint here.
func generateETag(body []byte) string {
	digest := md5.Sum(body)
	return fmt.Sprintf("\"%x\"", digest[:])
}
// parseCacheControl extracts the max-age directive from a Cache-Control header
// value and returns it as a duration.
//
// Directives are matched case-insensitively and may appear anywhere in the
// comma-separated list, with or without surrounding whitespace. It returns 0
// when header is empty, has no max-age directive, or the value is not an
// integer. Values too large to represent as a time.Duration are clamped to
// the largest whole number of seconds that fits.
func parseCacheControl(header string) time.Duration {
	const (
		directive = "max-age="
		// maxSeconds is the largest second count that does not overflow
		// time.Duration when multiplied by time.Second.
		maxSeconds = int64(1<<63-1) / int64(time.Second)
	)

	for _, part := range strings.Split(header, ",") {
		// Trim before matching: every directive after the first normally
		// carries a leading space ("public, max-age=60").
		part = strings.TrimSpace(part)
		if len(part) <= len(directive) || !strings.EqualFold(part[:len(directive)], directive) {
			continue
		}
		value := strings.Trim(part[len(directive):], "\" \t")
		seconds, err := strconv.ParseInt(value, 10, 64)
		if err != nil {
			continue
		}
		if seconds > maxSeconds {
			seconds = maxSeconds
		}
		return time.Duration(seconds) * time.Second
	}
	return 0
}
// getLimiter returns the rate limiter associated with ip, creating and
// registering a new one (rateLimit events per second, burst rateBurst) the
// first time an address is seen. The lookup and the insertion happen under
// rateMutex, so concurrent callers for the same ip share one limiter.
func getLimiter(ip string) *rate.Limiter {
	rateMutex.Lock()
	defer rateMutex.Unlock()

	limiter, known := rateLimiters[ip]
	if !known {
		limiter = rate.NewLimiter(rateLimit, rateBurst)
		rateLimiters[ip] = limiter
	}
	return limiter
}
func handler(w http.ResponseWriter, r *http.Request) {
configMux.RLock()
defer configMux.RUnlock()
// Rate limiting based on client IP
clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")] // Strip port
limiter := getLimiter(clientIP)
if !limiter.Allow() {
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
trafficLogger.Printf("Rate limited client %s", clientIP)
return
}
target, exists := config.Routes[r.Host]
skipVerify := config.TrustTarget[r.Host]
noHTTPSRedirect := config.NoHTTPSRedirect[r.Host]
if !exists {
if target, exists = config.Routes["*"]; !exists {
http.Error(w, "Host not configured", http.StatusNotFound)
return
}
skipVerify = config.TrustTarget["*"]
noHTTPSRedirect = config.NoHTTPSRedirect["*"]
}
isHTTPS := target[:5] == "https"
isHTTPReq := r.TLS == nil
if isHTTPReq && isHTTPS && !noHTTPSRedirect {
redirectURL := "https://" + r.Host + r.RequestURI
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
return
}
cacheKey := r.URL.String()
cacheMutex.RLock()
if cached, ok := cache[cacheKey]; ok && time.Since(cached.cachedAt) < cached.cacheDuration {
if etag := r.Header.Get("If-None-Match"); etag != "" && etag == cached.etag {
w.WriteHeader(http.StatusNotModified)
trafficLogger.Printf("Served 304 Not Modified from cache: %s [ETag: %s]", cacheKey, cached.etag)
} else {
for k, v := range cached.headers {
w.Header()[k] = v
}
w.WriteHeader(cached.statusCode)
w.Write(cached.body)
trafficLogger.Printf("Served from cache: %s [ETag: %s]", cacheKey, cached.etag)
}
cacheMutex.RUnlock()
return
}
cacheMutex.RUnlock()
proxy := getReverseProxy(target, skipVerify)
if proxy == nil {
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
return
}
targetURL, err := url.Parse(target)
if err != nil {
errorLogger.Printf("Error parsing target URL %s: %v", target, err)
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
return
}
// WebSocket upgrade handling
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
hijacker, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "WebSocket upgrade not supported", http.StatusInternalServerError)
return
}
conn, _, err := hijacker.Hijack()
if err != nil {
errorLogger.Printf("Failed to hijack connection for WebSocket: %v", err)
return
}
defer conn.Close()
targetConn, err := transportPool.Dial("tcp", targetURL.Host)
if err != nil {
errorLogger.Printf("Failed to dial target for WebSocket: %v", err)
return
}
defer targetConn.Close()
// Forward request to target
if err := r.Write(targetConn); err != nil {
errorLogger.Printf("Failed to forward WebSocket request: %v", err)
return
}
// Pipe connections
go io.Copy(conn, targetConn)
io.Copy(targetConn, conn)
trafficLogger.Printf("WebSocket connection established: %s -> %s", r.Host, target)
return
}
proxy.ServeHTTP(w, r)
}