From bd1c372435fcca11538a7eea641af7d243fc02ee Mon Sep 17 00:00:00 2001 From: ghostersk Date: Sat, 1 Mar 2025 19:13:22 +0000 Subject: [PATCH] added test --- bcp/{ => v1}/certificate.go | 0 bcp/{ => v1}/config.go | 0 bcp/{ => v1}/main.go | 0 bcp/{ => v1}/proxy.go | 0 bcp/{ => v1}/utils.go | 0 bcp/v2/certificate.go | 124 ++++++++++++++ bcp/v2/config.go | 153 +++++++++++++++++ bcp/v2/main.go | 197 +++++++++++++++++++++ bcp/v2/proxy.go | 332 ++++++++++++++++++++++++++++++++++++ bcp/v2/utils.go | 10 ++ certificate.go | 19 ++- config.go | 43 +++-- main.go | 96 ++++++----- proxy.go | 156 ++++++++++------- proxy_test.go | 168 ++++++++++++++++++ utils.go | 9 +- 16 files changed, 1179 insertions(+), 128 deletions(-) rename bcp/{ => v1}/certificate.go (100%) rename bcp/{ => v1}/config.go (100%) rename bcp/{ => v1}/main.go (100%) rename bcp/{ => v1}/proxy.go (100%) rename bcp/{ => v1}/utils.go (100%) create mode 100644 bcp/v2/certificate.go create mode 100644 bcp/v2/config.go create mode 100644 bcp/v2/main.go create mode 100644 bcp/v2/proxy.go create mode 100644 bcp/v2/utils.go create mode 100644 proxy_test.go diff --git a/bcp/certificate.go b/bcp/v1/certificate.go similarity index 100% rename from bcp/certificate.go rename to bcp/v1/certificate.go diff --git a/bcp/config.go b/bcp/v1/config.go similarity index 100% rename from bcp/config.go rename to bcp/v1/config.go diff --git a/bcp/main.go b/bcp/v1/main.go similarity index 100% rename from bcp/main.go rename to bcp/v1/main.go diff --git a/bcp/proxy.go b/bcp/v1/proxy.go similarity index 100% rename from bcp/proxy.go rename to bcp/v1/proxy.go diff --git a/bcp/utils.go b/bcp/v1/utils.go similarity index 100% rename from bcp/utils.go rename to bcp/v1/utils.go diff --git a/bcp/v2/certificate.go b/bcp/v2/certificate.go new file mode 100644 index 0000000..7fe63f7 --- /dev/null +++ b/bcp/v2/certificate.go @@ -0,0 +1,124 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "time" +) + +func generateSelfSignedCert() error { + if err := os.MkdirAll(certDir, 0755); err != nil { + return fmt.Errorf("failed to create certificate directory %s: %v", certDir, err) + } + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return fmt.Errorf("failed to generate private key: %v", err) + } + + notBefore := time.Now() + notAfter := notBefore.Add(365 * 24 * time.Hour) + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return fmt.Errorf("failed to generate serial number: %v", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Proxy Self-Signed"}, + CommonName: "localhost", + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost", "*.example.com"}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return fmt.Errorf("failed to create certificate: %v", err) + } + + certOut, err := os.Create(certPath) + if err != nil { + return fmt.Errorf("failed to open %s for writing: %v", certPath, err) + } + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + certOut.Close() + return fmt.Errorf("failed to encode certificate: %v", err) + } + certOut.Close() + + keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + return fmt.Errorf("failed to open %s for writing: %v", keyPath, err) + } + if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil { + keyOut.Close() + return fmt.Errorf("failed to encode private key: %v", err) + } + keyOut.Close() + + refreshLogger.Printf("Generated self-signed certificate in %s", certDir) + return nil +} + +func loadCertificate() error { + certFile, err := os.ReadFile(certPath) + if err != nil { + return fmt.Errorf("failed to read certificate %s: %v", certPath, err) + } + keyFile, err := os.ReadFile(keyPath) + if err != nil { + return fmt.Errorf("failed to read key %s: %v", keyPath, err) + } + + newCert, err := tls.X509KeyPair(certFile, keyFile) + if err != nil { + return fmt.Errorf("failed to parse certificate: %v", err) + } + + configMux.Lock() + cert = &newCert + configMux.Unlock() + return nil +} + +func monitorCertificates() { + var lastModTime time.Time + for { + certInfo, err := os.Stat(certPath) + if err != nil { + errorLogger.Printf("Error checking certificate: %v", err) + time.Sleep(5 * time.Second) + continue + } + + keyInfo, err := os.Stat(keyPath) + if err != nil { + errorLogger.Printf("Error checking key: %v", err) + time.Sleep(5 * time.Second) + continue + } + + if certInfo.ModTime() != lastModTime || keyInfo.ModTime() != lastModTime { + if err := loadCertificate(); err != nil { + errorLogger.Printf("Error reloading certificate: %v", err) + } else { + refreshLogger.Println("Certificate reloaded successfully") + lastModTime = certInfo.ModTime() + } + } + time.Sleep(5 * time.Second) + } +} diff --git a/bcp/v2/config.go b/bcp/v2/config.go new file mode 100644 index 0000000..c7cd801 --- /dev/null +++ b/bcp/v2/config.go @@ -0,0 +1,153 @@ +package main + +import ( + "fmt" + "os" + "time" + + "gopkg.in/yaml.v2" +) + +type Config struct { + ListenHTTP string `yaml:"listen_http"` + ListenHTTPS string `yaml:"listen_https"` + CertDir string `yaml:"cert_dir"` + CertFile string `yaml:"cert_file"` + KeyFile string `yaml:"key_file"` + Routes map[string]string `yaml:"routes"` + TrustTarget map[string]bool `yaml:"trust_target"` + NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"` +} + +func loadConfig() (Config, error) { + data, err := os.ReadFile(configPath) + if err != nil { + return Config{}, fmt.Errorf("failed to read config %s: %v", configPath, err) + } + + var cfg Config + err = yaml.Unmarshal(data, &cfg) + if err != nil { + return Config{}, fmt.Errorf("failed to unmarshal config: %v", err) + } + return cfg, nil +} + +func generateDefaultConfig() Config { + return Config{ + ListenHTTP: ":80", + ListenHTTPS: ":443", + CertDir: "certificates", + CertFile: "certificate.pem", + KeyFile: "key.pem", + Routes: map[string]string{ + "*": "https://127.0.0.1:3000", + "main.example.com": "https://10.100.111.254:4444", + }, + TrustTarget: map[string]bool{ + "*": true, + "main.example.com": true, + }, + NoHTTPSRedirect: map[string]bool{ + "*": false, + "main.example.com": false, + }, + } +} + +func saveConfig(cfg Config) error { + if err := os.MkdirAll(baseDir, 0755); err != nil { + return fmt.Errorf("failed to create base directory %s: %v", baseDir, err) + } + + data, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("failed to marshal config: %v", err) + } + + err = os.WriteFile(configPath, data, 0644) + if err != nil { + return fmt.Errorf("failed to write config to %s: %v", configPath, err) + } + return nil +} + +func monitorConfig() { + var lastModTime time.Time + for { + configInfo, err := os.Stat(configPath) + if err != nil { + errorLogger.Printf("Error checking config file: %v", err) + time.Sleep(5 * time.Second) + continue + } + + if configInfo.ModTime() != lastModTime { + newConfig, err := loadConfig() + if err != nil { + errorLogger.Printf("Error reloading config: %v", err) + } else { + configMux.Lock() + if newConfig.ListenHTTP != config.ListenHTTP { + config.ListenHTTP = newConfig.ListenHTTP + refreshLogger.Printf("Updated listen_http to %s", config.ListenHTTP) + } + if newConfig.ListenHTTPS != config.ListenHTTPS { + config.ListenHTTPS = newConfig.ListenHTTPS + refreshLogger.Printf("Updated listen_https to %s", config.ListenHTTPS) + } + if newConfig.CertDir != config.CertDir || newConfig.CertFile != config.CertFile || newConfig.KeyFile != config.KeyFile { + config.CertDir = newConfig.CertDir + config.CertFile = newConfig.CertFile + config.KeyFile = newConfig.KeyFile + updatePaths() + if err := loadCertificate(); err != nil { + errorLogger.Printf("Error reloading certificate after path change: %v", err) + } else { + refreshLogger.Println("Updated certificate paths and reloaded certificate") + } + } + for k, v := range newConfig.Routes { + if oldV, exists := config.Routes[k]; !exists || oldV != v { + config.Routes[k] = v + refreshLogger.Printf("Updated route %s to %s", k, v) + } + } + for k := range config.Routes { + if _, exists := newConfig.Routes[k]; !exists { + delete(config.Routes, k) + refreshLogger.Printf("Removed route %s", k) + } + } + for k, v := range newConfig.TrustTarget { + if oldV, exists := config.TrustTarget[k]; !exists || oldV != v { + config.TrustTarget[k] = v + refreshLogger.Printf("Updated trust_target %s to %v", k, v) + } + } + for k := range config.TrustTarget { + if _, exists := newConfig.TrustTarget[k]; !exists { + delete(config.TrustTarget, k) + refreshLogger.Printf("Removed trust_target %s", k) + } + } + for k, v := range newConfig.NoHTTPSRedirect { + if oldV, exists := config.NoHTTPSRedirect[k]; !exists || oldV != v { + config.NoHTTPSRedirect[k] = v + refreshLogger.Printf("Updated no_https_redirect %s to %v", k, v) + } + } + for k := range config.NoHTTPSRedirect { + if _, exists := newConfig.NoHTTPSRedirect[k]; !exists { + delete(config.NoHTTPSRedirect, k) + refreshLogger.Printf("Removed no_https_redirect %s", k) + } + } + configMux.Unlock() + refreshLogger.Println("Config reloaded successfully") + lastModTime = configInfo.ModTime() + } + } + time.Sleep(5 * time.Second) + } +} diff --git a/bcp/v2/main.go b/bcp/v2/main.go new file mode 100644 index 0000000..e245f60 --- /dev/null +++ b/bcp/v2/main.go @@ -0,0 +1,197 @@ +package main + +import ( + "crypto/tls" + "io" + "log" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +var ( + config Config + configMux sync.RWMutex + cert *tls.Certificate + baseDir string + certDir string + certPath string + keyPath string + configPath string + + // Loggers + errorLogger *log.Logger + refreshLogger *log.Logger + trafficLogger *log.Logger +) + +func init() { + // Get the absolute path of the running executable (or main.go in `go run`) + exePath, err := os.Executable() + if err != nil { + log.Fatalf("Failed to get executable path: %v", err) + } + baseDir = filepath.Dir(exePath) + /* + if runtime.GOOS == "windows" && len(os.Args) > 0 && filepath.Ext(os.Args[0]) == ".go" { + var err error + baseDir, err = os.Getwd() + if err != nil { + log.Fatalf("Failed to get working directory: %v", err) + } + } else { + exePath, err := os.Executable() + if err != nil { + log.Fatalf("Failed to get executable path: %v", err) + } + baseDir = filepath.Dir(exePath) + } + */ + // Setup logging + setupLogging() +} + +func setupLogging() { + logsDir := filepath.Join(baseDir, "logs") + if err := os.MkdirAll(logsDir, 0755); err != nil { + log.Fatalf("Failed to create logs directory: %v", err) + } + + // Error log + errorFile, err := os.OpenFile(filepath.Join(logsDir, "errors.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Fatalf("Failed to open errors.log: %v", err) + } + errorLogger = log.New(io.MultiWriter(os.Stdout, errorFile), "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile) + + // Refresh log + refreshFile, err := os.OpenFile(filepath.Join(logsDir, "refresh.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Fatalf("Failed to open refresh.log: %v", err) + } + refreshLogger = log.New(io.MultiWriter(os.Stdout, refreshFile), "REFRESH: ", log.Ldate|log.Ltime|log.Lshortfile) + + // Traffic log (daily rotation) + go manageTrafficLogs(logsDir) +} + +func manageTrafficLogs(logsDir string) { + for { + dateStr := time.Now().Format("2006-01-02") + trafficFile, err := os.OpenFile(filepath.Join(logsDir, "traffic-"+dateStr+".log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + errorLogger.Printf("Failed to open traffic log: %v", err) + time.Sleep(1 * time.Minute) + continue + } + trafficLogger = log.New(io.MultiWriter(os.Stdout, trafficFile), "TRAFFIC: ", log.Ldate|log.Ltime) + + // Cleanup old logs + cleanupOldLogs(logsDir) + + // Wait until next day + nextDay := time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour) + time.Sleep(time.Until(nextDay)) + trafficFile.Close() + } +} + +func cleanupOldLogs(logsDir string) { + files, err := os.ReadDir(logsDir) + if err != nil { + errorLogger.Printf("Failed to read logs directory: %v", err) + return + } + + cutoff := time.Now().AddDate(0, 0, -7) // 7 days ago + for _, file := range files { + if strings.HasPrefix(file.Name(), "traffic-") && file.Name() != "traffic-"+time.Now().Format("2006-01-02")+".log" { + info, err := file.Info() + if err != nil { + continue + } + if info.ModTime().Before(cutoff) { + if err := os.Remove(filepath.Join(logsDir, file.Name())); err != nil { + errorLogger.Printf("Failed to remove old log %s: %v", file.Name(), err) + } else { + refreshLogger.Printf("Removed old traffic log: %s", file.Name()) + } + } + } + } +} + +func main() { + configPath = filepath.Join(baseDir, "config.yaml") + if _, err := os.Stat(configPath); os.IsNotExist(err) { + config = generateDefaultConfig() + if err := saveConfig(config); err != nil { + errorLogger.Fatalf("Failed to save default config: %v", err) + } + refreshLogger.Println("Generated default config file") + } else { + cfg, err := loadConfig() + if err != nil { + errorLogger.Fatalf("Failed to load config: %v", err) + } + config = cfg + } + + updatePaths() + + _, certErr := os.Stat(certPath) + _, keyErr := os.Stat(keyPath) + if os.IsNotExist(certErr) || os.IsNotExist(keyErr) { + if err := generateSelfSignedCert(); err != nil { + errorLogger.Fatalf("Failed to generate self-signed certificate: %v", err) + } + } + + if err := loadCertificate(); err != nil { + errorLogger.Fatalf("Failed to load certificate: %v", err) + } + + go monitorCertificates() + go monitorConfig() + + httpServer := &http.Server{ + Addr: config.ListenHTTP, + Handler: http.HandlerFunc(handler), + MaxHeaderBytes: 1 << 20, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + + tlsConfig := &tls.Config{ + GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + configMux.RLock() + defer configMux.RUnlock() + return cert, nil + }, + } + httpsServer := &http.Server{ + Addr: config.ListenHTTPS, + Handler: http.HandlerFunc(handler), + TLSConfig: tlsConfig, + MaxHeaderBytes: 1 << 20, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + + go func() { + log.Printf("Starting HTTP server on %s", config.ListenHTTP) + trafficLogger.Printf("Starting HTTP server on %s", config.ListenHTTP) + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + errorLogger.Fatalf("HTTP server error: %v", err) + } + }() + + log.Printf("Starting HTTPS server on %s", config.ListenHTTPS) + trafficLogger.Printf("Starting HTTPS server on %s", config.ListenHTTPS) + if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { + errorLogger.Fatalf("HTTPS server error: %v", err) + } +} diff --git a/bcp/v2/proxy.go b/bcp/v2/proxy.go new file mode 100644 index 0000000..c458e92 --- /dev/null +++ b/bcp/v2/proxy.go @@ -0,0 +1,332 @@ +package main + +import ( + "bytes" + "compress/gzip" + "crypto/md5" + "crypto/tls" + "fmt" + "io" + "log" + "net" + "net/http" + "net/http/httputil" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "golang.org/x/time/rate" +) + +var ( + transportPool = &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext, + ResponseHeaderTimeout: 60 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + cache = make(map[string]cachedResponse) + defaultCacheTTL = 5 * time.Minute + cacheMutex sync.RWMutex + + // Rate limiter per client IP + rateLimiters = make(map[string]*rate.Limiter) + rateMutex sync.RWMutex + rateLimit = rate.Limit(10) // 10 requests per second per client + rateBurst = 20 // Allow burst of 20 requests +) + +type cachedResponse struct { + body []byte + headers http.Header + statusCode int + cachedAt time.Time + cacheDuration time.Duration + etag string +} + +func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy { + targetURL, err := url.Parse(target) + if err != nil { + log.Printf("Error parsing target URL %s: %v", target, err) + errorLogger.Printf("Error parsing target URL %s: %v", target, err) + return nil + } + + director := func(req *http.Request) { + req.URL.Scheme = targetURL.Scheme + req.URL.Host = targetURL.Host + // Preserve client's Host header for session continuity + if req.Header.Get("Host") != "" { + req.Host = req.Header.Get("Host") + } else { + req.Host = targetURL.Host + } + + // Preserve all request headers, including cookies + for k, v := range req.Header { + req.Header[k] = v + } + + // Preserve query parameters and path + req.URL.RawQuery = req.URL.RawQuery + if targetURL.Path != "" { + req.URL.Path = strings.TrimPrefix(req.URL.Path, "/") + req.URL.Path = singleJoin(targetURL.Path, req.URL.Path) + } + + // WebSocket support: pass upgrade headers + if strings.ToLower(req.Header.Get("Upgrade")) == "websocket" { + req.Header.Set("Connection", "Upgrade") + } + + trafficLogger.Printf("Request: %s %s -> %s [Host: %s]", req.Method, req.URL.String(), target, req.Host) + } + + transport := transportPool + if !skipVerify { + transport = &http.Transport{ + MaxIdleConns: transportPool.MaxIdleConns, + MaxIdleConnsPerHost: transportPool.MaxIdleConnsPerHost, + IdleConnTimeout: transportPool.IdleConnTimeout, + TLSClientConfig: &tls.Config{InsecureSkipVerify: false}, + } + } + + return &httputil.ReverseProxy{ + Director: director, + Transport: transport, + ModifyResponse: func(resp *http.Response) error { + // Preserve all response headers, including Set-Cookie + for k, v := range resp.Header { + resp.Header[k] = v + } + + // Handle WebSocket upgrade + if strings.ToLower(resp.Header.Get("Upgrade")) == "websocket" { + return nil // No further modification needed for WebSocket + } + + // Compression if client supports it and response isn’t already compressed + if resp.Header.Get("Content-Encoding") == "" && strings.Contains(resp.Request.Header.Get("Accept-Encoding"), "gzip") { + body, err := io.ReadAll(resp.Body) + if err != nil { + errorLogger.Printf("Error reading response body for compression: %v", err) + return err + } + resp.Body.Close() + + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + if _, err := gw.Write(body); err != nil { + errorLogger.Printf("Error compressing response: %v", err) + return err + } + gw.Close() + + resp.Body = io.NopCloser(&buf) + resp.Header.Set("Content-Encoding", "gzip") + resp.Header.Del("Content-Length") // Length changes after compression + trafficLogger.Printf("Compressed response for %s", resp.Request.URL.String()) + } + + // Cache static content with ETag support + if shouldCache(resp) { + body, err := io.ReadAll(resp.Body) + if err != nil { + errorLogger.Printf("Error reading response body for caching: %v", err) + return err + } + resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(body)) + + etag := resp.Header.Get("ETag") + if etag == "" { + etag = generateETag(body) // Simple ETag generation if not provided + } + + cacheDuration := parseCacheControl(resp.Header.Get("Cache-Control")) + if cacheDuration == 0 { + cacheDuration = defaultCacheTTL + } + + cacheMutex.Lock() + cache[resp.Request.URL.String()] = cachedResponse{ + body: body, + headers: resp.Header.Clone(), + statusCode: resp.StatusCode, + cachedAt: time.Now(), + cacheDuration: cacheDuration, + etag: etag, + } + cacheMutex.Unlock() + trafficLogger.Printf("Cached response for %s [ETag: %s]", resp.Request.URL.String(), etag) + } + + trafficLogger.Printf("Response: %s %d from %s", resp.Status, resp.StatusCode, target) + return nil + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + log.Printf("Proxy error for %s: %v", r.Host, err) + errorLogger.Printf("Proxy error for %s: %v", r.Host, err) + http.Error(w, "Proxy error", http.StatusBadGateway) + }, + } +} + +func shouldCache(resp *http.Response) bool { + if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK { + return false + } + contentType := resp.Header.Get("Content-Type") + return strings.HasPrefix(contentType, "text/") || strings.HasPrefix(contentType, "image/") || + strings.HasPrefix(contentType, "application/javascript") || strings.HasPrefix(contentType, "application/json") +} + +func singleJoin(prefix, suffix string) string { + prefix = strings.TrimSuffix(prefix, "/") + suffix = strings.TrimPrefix(suffix, "/") + return prefix + "/" + suffix +} + +func generateETag(body []byte) string { + return fmt.Sprintf(`"%x"`, md5.Sum(body)) // Simple ETag based on MD5 hash +} + +func parseCacheControl(header string) time.Duration { + if header == "" { + return 0 + } + parts := strings.Split(header, ",") + for _, part := range parts { + if strings.Contains(part, "max-age=") { + ageStr := strings.TrimPrefix(part, "max-age=") + ageStr = strings.TrimSpace(ageStr) + if age, err := strconv.Atoi(ageStr); err == nil { + return time.Duration(age) * time.Second + } + } + } + return 0 +} + +func getLimiter(ip string) *rate.Limiter { + rateMutex.Lock() + defer rateMutex.Unlock() + + if limiter, exists := rateLimiters[ip]; exists { + return limiter + } + limiter := rate.NewLimiter(rateLimit, rateBurst) + rateLimiters[ip] = limiter + return limiter +} + +func handler(w http.ResponseWriter, r *http.Request) { + configMux.RLock() + defer configMux.RUnlock() + + // Rate limiting based on client IP + clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")] // Strip port + limiter := getLimiter(clientIP) + if !limiter.Allow() { + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + trafficLogger.Printf("Rate limited client %s", clientIP) + return + } + + target, exists := config.Routes[r.Host] + skipVerify := config.TrustTarget[r.Host] + noHTTPSRedirect := config.NoHTTPSRedirect[r.Host] + + if !exists { + if target, exists = config.Routes["*"]; !exists { + http.Error(w, "Host not configured", http.StatusNotFound) + return + } + skipVerify = config.TrustTarget["*"] + noHTTPSRedirect = config.NoHTTPSRedirect["*"] + } + + isHTTPS := target[:5] == "https" + isHTTPReq := r.TLS == nil + + if isHTTPReq && isHTTPS && !noHTTPSRedirect { + redirectURL := "https://" + r.Host + r.RequestURI + http.Redirect(w, r, redirectURL, http.StatusMovedPermanently) + return + } + + cacheKey := r.URL.String() + cacheMutex.RLock() + if cached, ok := cache[cacheKey]; ok && time.Since(cached.cachedAt) < cached.cacheDuration { + if etag := r.Header.Get("If-None-Match"); etag != "" && etag == cached.etag { + w.WriteHeader(http.StatusNotModified) + trafficLogger.Printf("Served 304 Not Modified from cache: %s [ETag: %s]", cacheKey, cached.etag) + } else { + for k, v := range cached.headers { + w.Header()[k] = v + } + w.WriteHeader(cached.statusCode) + w.Write(cached.body) + trafficLogger.Printf("Served from cache: %s [ETag: %s]", cacheKey, cached.etag) + } + cacheMutex.RUnlock() + return + } + cacheMutex.RUnlock() + + proxy := getReverseProxy(target, skipVerify) + if proxy == nil { + http.Error(w, "Invalid target configuration", http.StatusInternalServerError) + return + } + + targetURL, err := url.Parse(target) + if err != nil { + errorLogger.Printf("Error parsing target URL %s: %v", target, err) + http.Error(w, "Invalid target configuration", http.StatusInternalServerError) + return + } + + // WebSocket upgrade handling + if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" { + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "WebSocket upgrade not supported", http.StatusInternalServerError) + return + } + conn, _, err := hijacker.Hijack() + if err != nil { + errorLogger.Printf("Failed to hijack connection for WebSocket: %v", err) + return + } + defer conn.Close() + + targetConn, err := transportPool.Dial("tcp", targetURL.Host) + if err != nil { + errorLogger.Printf("Failed to dial target for WebSocket: %v", err) + return + } + defer targetConn.Close() + + // Forward request to target + if err := r.Write(targetConn); err != nil { + errorLogger.Printf("Failed to forward WebSocket request: %v", err) + return + } + + // Pipe connections + go io.Copy(conn, targetConn) + io.Copy(targetConn, conn) + trafficLogger.Printf("WebSocket connection established: %s -> %s", r.Host, target) + return + } + + proxy.ServeHTTP(w, r) +} diff --git a/bcp/v2/utils.go b/bcp/v2/utils.go new file mode 100644 index 0000000..50c4e43 --- /dev/null +++ b/bcp/v2/utils.go @@ -0,0 +1,10 @@ +package main + +import "path/filepath" + +func updatePaths() { + certDir = filepath.Join(baseDir, config.CertDir) + certPath = filepath.Join(certDir, config.CertFile) + keyPath = filepath.Join(certDir, config.KeyFile) + configPath = filepath.Join(baseDir, "config.yaml") +} diff --git a/certificate.go b/certificate.go index 7fe63f7..7859b02 100644 --- a/certificate.go +++ b/certificate.go @@ -13,23 +13,30 @@ import ( "time" ) +// generateSelfSignedCert creates a self-signed TLS certificate if one doesn’t exist func generateSelfSignedCert() error { + // Ensure certificate directory exists if err := os.MkdirAll(certDir, 0755); err != nil { return fmt.Errorf("failed to create certificate directory %s: %v", certDir, err) } + // Generate a 2048-bit RSA private key priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return fmt.Errorf("failed to generate private key: %v", err) } + // Set certificate validity period (1 year) notBefore := time.Now() notAfter := notBefore.Add(365 * 24 * time.Hour) + + // Generate a random serial number serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) if err != nil { return fmt.Errorf("failed to generate serial number: %v", err) } + // Define certificate template with basic attributes template := x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{ @@ -41,14 +48,16 @@ func generateSelfSignedCert() error { KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, - DNSNames: []string{"localhost", "*.example.com"}, + DNSNames: []string{"localhost", "*.example.com"}, // Supported domains } + // Create self-signed certificate certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { return fmt.Errorf("failed to create certificate: %v", err) } + // Write certificate to file certOut, err := os.Create(certPath) if err != nil { return fmt.Errorf("failed to open %s for writing: %v", certPath, err) @@ -59,7 +68,8 @@ func generateSelfSignedCert() error { } certOut.Close() - keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_WRONLY, 0600) + // Write private key to file + keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return fmt.Errorf("failed to open %s for writing: %v", keyPath, err) } @@ -73,6 +83,7 @@ func generateSelfSignedCert() error { return nil } +// loadCertificate reads and parses the certificate and key files into memory func loadCertificate() error { certFile, err := os.ReadFile(certPath) if err != nil { @@ -94,6 +105,7 @@ func loadCertificate() error { return nil } +// monitorCertificates watches for changes to certificate/key files and reloads them func monitorCertificates() { var lastModTime time.Time for { @@ -111,6 +123,7 @@ func monitorCertificates() { continue } + // Reload certificate if either file has changed if certInfo.ModTime() != lastModTime || keyInfo.ModTime() != lastModTime { if err := loadCertificate(); err != nil { errorLogger.Printf("Error reloading certificate: %v", err) @@ -119,6 +132,6 @@ func monitorCertificates() { lastModTime = certInfo.ModTime() } } - time.Sleep(5 * time.Second) + time.Sleep(5 * time.Second) // Poll every 5 seconds } } diff --git a/config.go b/config.go index c7cd801..a2558fc 100644 --- a/config.go +++ b/config.go @@ -8,17 +8,19 @@ import ( "gopkg.in/yaml.v2" ) +// Config defines the structure of the proxy configuration loaded from config.yaml type Config struct { - ListenHTTP string `yaml:"listen_http"` - ListenHTTPS string `yaml:"listen_https"` - CertDir string `yaml:"cert_dir"` - CertFile string `yaml:"cert_file"` - KeyFile string `yaml:"key_file"` - Routes map[string]string `yaml:"routes"` - TrustTarget map[string]bool `yaml:"trust_target"` - NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"` + ListenHTTP string `yaml:"listen_http"` // Port for HTTP server (e.g., ":80") + ListenHTTPS string `yaml:"listen_https"` // Port for HTTPS server (e.g., ":443") + CertDir string `yaml:"cert_dir"` // Directory for certificate files + CertFile string `yaml:"cert_file"` // Certificate filename + KeyFile string `yaml:"key_file"` // Private key filename + Routes map[string]string `yaml:"routes"` // Mapping of hostnames to target URLs + TrustTarget map[string]bool `yaml:"trust_target"` // Whether to skip TLS verification for targets + NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"` // Whether to skip HTTP->HTTPS redirect } +// loadConfig reads and parses the config.yaml file func loadConfig() (Config, error) { data, err := os.ReadFile(configPath) if err != nil { @@ -26,35 +28,36 @@ func loadConfig() (Config, error) { } var cfg Config - err = yaml.Unmarshal(data, &cfg) - if err != nil { + if err := yaml.Unmarshal(data, &cfg); err != nil { return Config{}, fmt.Errorf("failed to unmarshal config: %v", err) } return cfg, nil } +// generateDefaultConfig creates a default configuration if config.yaml doesn’t exist func generateDefaultConfig() Config { return Config{ ListenHTTP: ":80", ListenHTTPS: ":443", - CertDir: "certificates", + CertDir: "./certificate", CertFile: "certificate.pem", KeyFile: "key.pem", Routes: map[string]string{ - "*": "https://127.0.0.1:3000", - "main.example.com": "https://10.100.111.254:4444", + "*": "https://127.0.0.1:3000", // Wildcard route + "main.example.com": "https://10.100.111.254:4444", // Specific route }, TrustTarget: map[string]bool{ - "*": true, + "*": true, // Skip TLS verification by default "main.example.com": true, }, NoHTTPSRedirect: map[string]bool{ - "*": false, + "*": false, // Redirect HTTP to HTTPS by default "main.example.com": false, }, } } +// saveConfig writes the configuration to config.yaml func saveConfig(cfg Config) error { if err := os.MkdirAll(baseDir, 0755); err != nil { return fmt.Errorf("failed to create base directory %s: %v", baseDir, err) @@ -65,13 +68,13 @@ func saveConfig(cfg Config) error { return fmt.Errorf("failed to marshal config: %v", err) } - err = os.WriteFile(configPath, data, 0644) - if err != nil { + if err := os.WriteFile(configPath, data, 0644); err != nil { return fmt.Errorf("failed to write config to %s: %v", configPath, err) } return nil } +// monitorConfig watches config.yaml for changes and updates the in-memory config func monitorConfig() { var lastModTime time.Time for { @@ -88,6 +91,7 @@ func monitorConfig() { errorLogger.Printf("Error reloading config: %v", err) } else { configMux.Lock() + // Update individual fields only if they’ve changed if newConfig.ListenHTTP != config.ListenHTTP { config.ListenHTTP = newConfig.ListenHTTP refreshLogger.Printf("Updated listen_http to %s", config.ListenHTTP) @@ -107,6 +111,7 @@ func monitorConfig() { refreshLogger.Println("Updated certificate paths and reloaded certificate") } } + // Update routes for k, v := range newConfig.Routes { if oldV, exists := config.Routes[k]; !exists || oldV != v { config.Routes[k] = v @@ -119,6 +124,7 @@ func monitorConfig() { refreshLogger.Printf("Removed route %s", k) } } + // Update trust_target for k, v := range newConfig.TrustTarget { if oldV, exists := config.TrustTarget[k]; !exists || oldV != v { config.TrustTarget[k] = v @@ -131,6 +137,7 @@ func monitorConfig() { refreshLogger.Printf("Removed trust_target %s", k) } } + // Update no_https_redirect for k, v := range newConfig.NoHTTPSRedirect { if oldV, exists := config.NoHTTPSRedirect[k]; !exists || oldV != v { config.NoHTTPSRedirect[k] = v @@ -148,6 +155,6 @@ func monitorConfig() { lastModTime = configInfo.ModTime() } } - time.Sleep(5 * time.Second) + time.Sleep(5 * time.Second) // Poll every 5 seconds } } diff --git a/main.go b/main.go index e245f60..eda63d3 100644 --- a/main.go +++ b/main.go @@ -7,77 +7,78 @@ import ( "net/http" "os" "path/filepath" + "runtime" "strings" "sync" "time" ) +// Shared global variables used across all files var ( - config Config - configMux sync.RWMutex - cert *tls.Certificate - baseDir string - certDir string - certPath string - keyPath string - configPath string + config Config // Holds the proxy configuration loaded from config.yaml + configMux sync.RWMutex // Mutex for thread-safe access to config + cert *tls.Certificate // TLS certificate for HTTPS server + baseDir string // Base directory (working dir for go run, executable dir for binary) + certDir string // Directory for certificates + certPath string // Path to certificate file + keyPath string // Path to private key file + configPath string // Path to config.yaml - // Loggers - errorLogger *log.Logger - refreshLogger *log.Logger - trafficLogger *log.Logger + // Loggers for different types of events + errorLogger *log.Logger // Logs errors to errors.log and console + refreshLogger *log.Logger // Logs config/certificate refreshes to refresh.log and console + trafficLogger *log.Logger // Logs traffic to traffic-YYYY-MM-DD.log and console ) +// init sets up the base directory and initializes logging func init() { - // Get the absolute path of the running executable (or main.go in `go run`) - exePath, err := os.Executable() - if err != nil { - log.Fatalf("Failed to get executable path: %v", err) - } - baseDir = filepath.Dir(exePath) - /* - if runtime.GOOS == "windows" && len(os.Args) > 0 && filepath.Ext(os.Args[0]) == ".go" { - var err error - baseDir, err = os.Getwd() - if err != nil { - log.Fatalf("Failed to get working directory: %v", err) - } - } else { - exePath, err := os.Executable() - if err != nil { - log.Fatalf("Failed to get executable path: %v", err) - } - baseDir = filepath.Dir(exePath) + // Determine base directory based on execution context + if runtime.GOOS == "windows" && len(os.Args) > 0 && filepath.Ext(os.Args[0]) == ".go" { + // For "go run", use current working directory + var err error + baseDir, err = os.Getwd() + if err != nil { + log.Fatalf("Failed to get working directory: %v", err) } - */ - // Setup logging + } else { + // For compiled binary, use executable's directory + exePath, err := os.Executable() + if err != nil { + log.Fatalf("Failed to get executable path: %v", err) + } + baseDir = filepath.Dir(exePath) + } + + // Initialize logging to files and console setupLogging() } +// setupLogging configures loggers for errors, refreshes, and traffic func setupLogging() { logsDir := filepath.Join(baseDir, "logs") if err := os.MkdirAll(logsDir, 0755); err != nil { log.Fatalf("Failed to create logs directory: %v", err) } - // Error log + // Error log: writes to errors.log and stdout errorFile, err := os.OpenFile(filepath.Join(logsDir, "errors.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatalf("Failed to open errors.log: %v", err) } errorLogger = log.New(io.MultiWriter(os.Stdout, errorFile), "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile) - // Refresh log + // Refresh log: writes to refresh.log and stdout refreshFile, err := os.OpenFile(filepath.Join(logsDir, "refresh.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatalf("Failed to open refresh.log: %v", err) } refreshLogger = log.New(io.MultiWriter(os.Stdout, refreshFile), "REFRESH: ", log.Ldate|log.Ltime|log.Lshortfile) - // Traffic log (daily rotation) + // Traffic log: managed in a goroutine for daily rotation go manageTrafficLogs(logsDir) } +// manageTrafficLogs handles daily rotation of traffic logs with 7-day retention func manageTrafficLogs(logsDir string) { for { dateStr := time.Now().Format("2006-01-02") @@ -89,7 +90,7 @@ func manageTrafficLogs(logsDir string) { } trafficLogger = log.New(io.MultiWriter(os.Stdout, trafficFile), "TRAFFIC: ", log.Ldate|log.Ltime) - // Cleanup old logs + // Cleanup logs older than 7 days cleanupOldLogs(logsDir) // Wait until next day @@ -99,6 +100,7 @@ func manageTrafficLogs(logsDir string) { } } +// cleanupOldLogs removes traffic logs older than 7 days func cleanupOldLogs(logsDir string) { files, err := os.ReadDir(logsDir) if err != nil { @@ -106,7 +108,7 @@ func cleanupOldLogs(logsDir string) { return } - cutoff := time.Now().AddDate(0, 0, -7) // 7 days ago + cutoff := time.Now().AddDate(0, 0, -7) for _, file := range files { if strings.HasPrefix(file.Name(), "traffic-") && file.Name() != "traffic-"+time.Now().Format("2006-01-02")+".log" { info, err := file.Info() @@ -124,7 +126,9 @@ func cleanupOldLogs(logsDir string) { } } +// main is the entry point, setting up and running the HTTP/HTTPS servers func main() { + // Set config file path and load or generate initial config configPath = filepath.Join(baseDir, "config.yaml") if _, err := os.Stat(configPath); os.IsNotExist(err) { config = generateDefaultConfig() @@ -140,8 +144,10 @@ func main() { config = cfg } + // Update certificate and config paths based on loaded config updatePaths() + // Generate or load TLS certificate _, certErr := os.Stat(certPath) _, keyErr := os.Stat(keyPath) if os.IsNotExist(certErr) || os.IsNotExist(keyErr) { @@ -149,22 +155,24 @@ func main() { errorLogger.Fatalf("Failed to generate self-signed certificate: %v", err) } } - if err := loadCertificate(); err != nil { errorLogger.Fatalf("Failed to load certificate: %v", err) } + // Start background monitoring for config and certificate changes go monitorCertificates() go monitorConfig() + // Configure HTTP server with timeouts for robustness httpServer := &http.Server{ Addr: config.ListenHTTP, Handler: http.HandlerFunc(handler), - MaxHeaderBytes: 1 << 20, + MaxHeaderBytes: 1 << 20, // 1 MB max header size ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, } + // Configure HTTPS server with TLS and certificate fetching tlsConfig := &tls.Config{ GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { configMux.RLock() @@ -181,16 +189,16 @@ func main() { WriteTimeout: 10 * time.Second, } + // Start HTTP server in a goroutine go func() { - log.Printf("Starting HTTP server on %s", config.ListenHTTP) - trafficLogger.Printf("Starting HTTP server on %s", config.ListenHTTP) + refreshLogger.Printf("Starting HTTP server on %s", config.ListenHTTP) if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { errorLogger.Fatalf("HTTP server error: %v", err) } }() - log.Printf("Starting HTTPS server on %s", config.ListenHTTPS) - trafficLogger.Printf("Starting HTTPS server on %s", config.ListenHTTPS) + // Start HTTPS server in the main goroutine + refreshLogger.Printf("Starting HTTPS server on %s", config.ListenHTTPS) if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed { errorLogger.Fatalf("HTTPS server error: %v", err) } diff --git a/proxy.go b/proxy.go index c458e92..69cbe3a 100644 --- a/proxy.go +++ b/proxy.go @@ -21,26 +21,29 @@ import ( ) var ( + // transportPool configures HTTP transport with timeouts matching Nginx defaults transportPool = &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext, - ResponseHeaderTimeout: 60 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + MaxIdleConns: 100, // Max idle connections + MaxIdleConnsPerHost: 10, // Max idle per host + IdleConnTimeout: 90 * time.Second, // Idle connection timeout + DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext, // Dial timeout + ResponseHeaderTimeout: 60 * time.Second, // Response header timeout (Nginx-like) + TLSHandshakeTimeout: 10 * time.Second, // TLS handshake timeout + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // Default skip TLS verification } - cache = make(map[string]cachedResponse) - defaultCacheTTL = 5 * time.Minute - cacheMutex sync.RWMutex + cache = make(map[string]cachedResponse) // Cache for static responses + cacheMutex sync.RWMutex // Mutex for cache access - // Rate limiter per client IP + // Rate limiting per client IP rateLimiters = make(map[string]*rate.Limiter) rateMutex sync.RWMutex - rateLimit = rate.Limit(10) // 10 requests per second per client - rateBurst = 20 // Allow burst of 20 requests + rateLimit = rate.Limit(10) // 10 req/s per client + rateBurst = 20 // Burst allowance + + defaultCacheTTL = 5 * time.Minute // Default TTL for cached responses ) +// cachedResponse stores cached response details type cachedResponse struct { body []byte headers http.Header @@ -50,7 +53,8 @@ type cachedResponse struct { etag string } -func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy { +// getReverseProxy creates a reverse proxy for a target URL +func getReverseProxy(target string, skipVerify bool, originalReq *http.Request) *httputil.ReverseProxy { targetURL, err := url.Parse(target) if err != nil { log.Printf("Error parsing target URL %s: %v", target, err) @@ -58,61 +62,68 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy { return nil } + // director modifies the request to forward it to the target director := func(req *http.Request) { req.URL.Scheme = targetURL.Scheme req.URL.Host = targetURL.Host - // Preserve client's Host header for session continuity - if req.Header.Get("Host") != "" { - req.Host = req.Header.Get("Host") + // Preserve the original client’s Host header for session continuity + if originalReq.Header.Get("Host") != "" { + req.Host = originalReq.Header.Get("Host") } else { req.Host = targetURL.Host } - // Preserve all request headers, including cookies - for k, v := range req.Header { + // Copy all request headers to ensure cookies and session data are preserved + for k, v := range originalReq.Header { req.Header[k] = v } - // Preserve query parameters and path - req.URL.RawQuery = req.URL.RawQuery + // Preserve query parameters and full path + req.URL.RawQuery = originalReq.URL.RawQuery if targetURL.Path != "" { - req.URL.Path = strings.TrimPrefix(req.URL.Path, "/") + req.URL.Path = strings.TrimPrefix(originalReq.URL.Path, "/") req.URL.Path = singleJoin(targetURL.Path, req.URL.Path) } - // WebSocket support: pass upgrade headers + // Ensure WebSocket headers are passed correctly if strings.ToLower(req.Header.Get("Upgrade")) == "websocket" { req.Header.Set("Connection", "Upgrade") } - trafficLogger.Printf("Request: %s %s -> %s [Host: %s]", req.Method, req.URL.String(), target, req.Host) + // Log request details for debugging + trafficLogger.Printf("Request: %s %s -> %s [Host: %s] [Headers: %v]", req.Method, req.URL.String(), target, req.Host, req.Header) } transport := transportPool if !skipVerify { + // Use a new transport with TLS verification if skipVerify is false transport = &http.Transport{ - MaxIdleConns: transportPool.MaxIdleConns, - MaxIdleConnsPerHost: transportPool.MaxIdleConnsPerHost, - IdleConnTimeout: transportPool.IdleConnTimeout, - TLSClientConfig: &tls.Config{InsecureSkipVerify: false}, + MaxIdleConns: transportPool.MaxIdleConns, + MaxIdleConnsPerHost: transportPool.MaxIdleConnsPerHost, + IdleConnTimeout: transportPool.IdleConnTimeout, + DialContext: transportPool.DialContext, + ResponseHeaderTimeout: transportPool.ResponseHeaderTimeout, + TLSHandshakeTimeout: transportPool.TLSHandshakeTimeout, + TLSClientConfig: &tls.Config{InsecureSkipVerify: false}, } } return &httputil.ReverseProxy{ Director: director, Transport: transport, + // ModifyResponse processes responses, avoiding corruption ModifyResponse: func(resp *http.Response) error { - // Preserve all response headers, including Set-Cookie + // Preserve all response headers, including Set-Cookie for sessions for k, v := range resp.Header { resp.Header[k] = v } - // Handle WebSocket upgrade + // Skip modification for WebSocket responses if strings.ToLower(resp.Header.Get("Upgrade")) == "websocket" { - return nil // No further modification needed for WebSocket + return nil } - // Compression if client supports it and response isn’t already compressed + // Apply compression only if not already compressed and client supports it if resp.Header.Get("Content-Encoding") == "" && strings.Contains(resp.Request.Header.Get("Accept-Encoding"), "gzip") { body, err := io.ReadAll(resp.Body) if err != nil { @@ -131,11 +142,11 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy { resp.Body = io.NopCloser(&buf) resp.Header.Set("Content-Encoding", "gzip") - resp.Header.Del("Content-Length") // Length changes after compression + resp.Header.Del("Content-Length") trafficLogger.Printf("Compressed response for %s", resp.Request.URL.String()) } - // Cache static content with ETag support + // Cache static content if applicable if shouldCache(resp) { body, err := io.ReadAll(resp.Body) if err != nil { @@ -147,7 +158,7 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy { etag := resp.Header.Get("ETag") if etag == "" { - etag = generateETag(body) // Simple ETag generation if not provided + etag = generateETag(body) } cacheDuration := parseCacheControl(resp.Header.Get("Cache-Control")) @@ -165,13 +176,18 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy { etag: etag, } cacheMutex.Unlock() + resp.Header.Set("ETag", etag) trafficLogger.Printf("Cached response for %s [ETag: %s]", resp.Request.URL.String(), etag) } - trafficLogger.Printf("Response: %s %d from %s", resp.Status, resp.StatusCode, target) + trafficLogger.Printf("Response: %s %d from %s [Headers: %v]", resp.Status, resp.StatusCode, target, resp.Header) return nil }, ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + if err.Error() == "context canceled" { + trafficLogger.Printf("Request canceled by client for %s: %v", r.Host, err) + return + } log.Printf("Proxy error for %s: %v", r.Host, err) errorLogger.Printf("Proxy error for %s: %v", r.Host, err) http.Error(w, "Proxy error", http.StatusBadGateway) @@ -179,6 +195,7 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy { } } +// shouldCache checks if a response should be cached based on method and content type func shouldCache(resp *http.Response) bool { if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK { return false @@ -188,25 +205,27 @@ func shouldCache(resp *http.Response) bool { strings.HasPrefix(contentType, "application/javascript") || strings.HasPrefix(contentType, "application/json") } +// singleJoin combines path segments with a single slash func singleJoin(prefix, suffix string) string { prefix = strings.TrimSuffix(prefix, "/") suffix = strings.TrimPrefix(suffix, "/") return prefix + "/" + suffix } +// generateETag creates an ETag from the response body using MD5 func generateETag(body []byte) string { - return fmt.Sprintf(`"%x"`, md5.Sum(body)) // Simple ETag based on MD5 hash + return fmt.Sprintf(`"%x"`, md5.Sum(body)) } +// parseCacheControl extracts max-age from the Cache-Control header for caching duration func parseCacheControl(header string) time.Duration { if header == "" { return 0 } parts := strings.Split(header, ",") for _, part := range parts { - if strings.Contains(part, "max-age=") { - ageStr := strings.TrimPrefix(part, "max-age=") - ageStr = strings.TrimSpace(ageStr) + if strings.HasPrefix(part, "max-age=") { + ageStr := strings.TrimSpace(strings.TrimPrefix(part, "max-age=")) if age, err := strconv.Atoi(ageStr); err == nil { return time.Duration(age) * time.Second } @@ -215,6 +234,7 @@ func parseCacheControl(header string) time.Duration { return 0 } +// getLimiter manages rate limiters per client IP func getLimiter(ip string) *rate.Limiter { rateMutex.Lock() defer rateMutex.Unlock() @@ -227,12 +247,13 @@ func getLimiter(ip string) *rate.Limiter { return limiter } +// handler processes incoming HTTP/HTTPS requests func handler(w http.ResponseWriter, r *http.Request) { configMux.RLock() defer configMux.RUnlock() // Rate limiting based on client IP - clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")] // Strip port + clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")] limiter := getLimiter(clientIP) if !limiter.Allow() { http.Error(w, "Too Many Requests", http.StatusTooManyRequests) @@ -240,6 +261,7 @@ func handler(w http.ResponseWriter, r *http.Request) { return } + // Retrieve target and settings for the requested host target, exists := config.Routes[r.Host] skipVerify := config.TrustTarget[r.Host] noHTTPSRedirect := config.NoHTTPSRedirect[r.Host] @@ -253,15 +275,16 @@ func handler(w http.ResponseWriter, r *http.Request) { noHTTPSRedirect = config.NoHTTPSRedirect["*"] } + // Handle HTTP->HTTPS redirect if applicable isHTTPS := target[:5] == "https" isHTTPReq := r.TLS == nil - if isHTTPReq && isHTTPS && !noHTTPSRedirect { redirectURL := "https://" + r.Host + r.RequestURI http.Redirect(w, r, redirectURL, http.StatusMovedPermanently) return } + // Check cache for static content cacheKey := r.URL.String() cacheMutex.RLock() if cached, ok := cache[cacheKey]; ok && time.Since(cached.cachedAt) < cached.cacheDuration { @@ -281,50 +304,65 @@ func handler(w http.ResponseWriter, r *http.Request) { } cacheMutex.RUnlock() - proxy := getReverseProxy(target, skipVerify) + proxy := getReverseProxy(target, skipVerify, r) if proxy == nil { http.Error(w, "Invalid target configuration", http.StatusInternalServerError) return } - targetURL, err := url.Parse(target) - if err != nil { - errorLogger.Printf("Error parsing target URL %s: %v", target, err) - http.Error(w, "Invalid target configuration", http.StatusInternalServerError) - return - } - - // WebSocket upgrade handling + // Handle WebSocket connections if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" { + targetURL, err := url.Parse(target) + if err != nil { + errorLogger.Printf("Error parsing target URL for WebSocket: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + hijacker, ok := w.(http.Hijacker) if !ok { - http.Error(w, "WebSocket upgrade not supported", http.StatusInternalServerError) + errorLogger.Printf("WebSocket upgrade not supported by server") + http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - conn, _, err := hijacker.Hijack() + clientConn, _, err := hijacker.Hijack() if err != nil { errorLogger.Printf("Failed to hijack connection for WebSocket: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - defer conn.Close() + defer clientConn.Close() - targetConn, err := transportPool.Dial("tcp", targetURL.Host) + dialer := transportPool + if !skipVerify { + dialer = &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: false}, + } + } + targetConn, err := dialer.Dial("tcp", targetURL.Host) if err != nil { errorLogger.Printf("Failed to dial target for WebSocket: %v", err) return } defer targetConn.Close() - // Forward request to target if err := r.Write(targetConn); err != nil { errorLogger.Printf("Failed to forward WebSocket request: %v", err) return } - // Pipe connections - go io.Copy(conn, targetConn) - io.Copy(targetConn, conn) + errChan := make(chan error, 2) + go func() { + _, err := io.Copy(targetConn, clientConn) + errChan <- err + }() + go func() { + _, err := io.Copy(clientConn, targetConn) + errChan <- err + }() trafficLogger.Printf("WebSocket connection established: %s -> %s", r.Host, target) + <-errChan + trafficLogger.Printf("WebSocket connection closed: %s -> %s", r.Host, target) return } diff --git a/proxy_test.go b/proxy_test.go new file mode 100644 index 0000000..217c49b --- /dev/null +++ b/proxy_test.go @@ -0,0 +1,168 @@ +package main + +import ( + "bufio" + "bytes" + "io" + "log" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +// Mock globals for testing +var ( + mockConfig = Config{ + Routes: map[string]string{ + "test.local": "http://mock-target", + "ws.local": "ws://mock-target", + "cache.local": "http://mock-target", + }, + TrustTarget: map[string]bool{ + "test.local": true, + "ws.local": true, + "cache.local": true, + }, + NoHTTPSRedirect: map[string]bool{ + "test.local": true, + "ws.local": true, + "cache.local": true, + }, + } + mockConfigMux = sync.RWMutex{} + mockLogger = log.New(io.Discard, "", 0) // Discard logs during testing +) + +// TestHandlerRoute tests basic routing to a target +func TestHandlerRoute(t *testing.T) { + // Mock target server + targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello from target")) + })) + defer targetServer.Close() + + // Setup mock globals + configMux.Lock() + config = mockConfig + config.Routes["test.local"] = targetServer.URL + trafficLogger = mockLogger + configMux.Unlock() + + // Create request + req, _ := http.NewRequest("GET", "http://test.local", nil) + rr := httptest.NewRecorder() + + // Run handler + handler(rr, req) + + // Check response + if status := rr.Code; status != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) + } + if body := rr.Body.String(); body != "Hello from target" { + t.Errorf("handler returned unexpected body: got %v want %v", body, "Hello from target") + } +} + +// TestHandlerWebSocket tests WebSocket upgrade handling +func TestHandlerWebSocket(t *testing.T) { + // Mock WebSocket target (simplified) + targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") == "websocket" { + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("Server doesn’t support hijacking") + } + conn, _, err := hj.Hijack() + if err != nil { + t.Fatal(err) + } + conn.Write([]byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n")) + conn.Close() + } + })) + defer targetServer.Close() + + // Setup mock globals + configMux.Lock() + config = mockConfig + config.Routes["ws.local"] = targetServer.URL + trafficLogger = mockLogger + configMux.Unlock() + + // Create WebSocket request + req, _ := http.NewRequest("GET", "http://ws.local", nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + + // Use a custom recorder for hijacking + rr := &hijackRecorder{ResponseRecorder: httptest.NewRecorder()} + handler(rr, req) + + // Check if hijacking occurred + if !rr.hijacked { + t.Errorf("Expected WebSocket hijacking, but it didn’t occur") + } +} + +// TestHandlerCache tests caching functionality +func TestHandlerCache(t *testing.T) { + // Mock target server + targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte("Cached content")) + })) + defer targetServer.Close() + + // Setup mock globals + configMux.Lock() + config = mockConfig + config.Routes["cache.local"] = targetServer.URL + trafficLogger = mockLogger + cache = make(map[string]cachedResponse) // Reset cache for test isolation + configMux.Unlock() + + // First request to cache + req, _ := http.NewRequest("GET", "http://cache.local", nil) + rr1 := httptest.NewRecorder() + handler(rr1, req) + + // Second request to hit cache + rr2 := httptest.NewRecorder() + handler(rr2, req) + + // Check if second response is from cache + if rr2.Code != http.StatusOK { + t.Errorf("Expected status OK, got %v", rr2.Code) + } + if body := rr2.Body.String(); body != "Cached content" { + t.Errorf("Expected cached response, got %v", body) + } +} + +// hijackRecorder mocks ResponseRecorder with Hijack support for WebSocket testing +type hijackRecorder struct { + *httptest.ResponseRecorder + hijacked bool +} + +func (hr *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hr.hijacked = true + return &mockConn{Reader: bytes.NewReader([]byte("mock response")), Writer: hr.Body}, nil, nil +} + +// mockConn is a simple net.Conn mock for testing +type mockConn struct { + io.Reader + io.Writer +} + +func (mc *mockConn) Close() error { return nil } +func (mc *mockConn) LocalAddr() net.Addr { return nil } +func (mc *mockConn) RemoteAddr() net.Addr { return nil } +func (mc *mockConn) SetDeadline(t time.Time) error { return nil } +func (mc *mockConn) SetReadDeadline(t time.Time) error { return nil } +func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/utils.go b/utils.go index 50c4e43..c593c28 100644 --- a/utils.go +++ b/utils.go @@ -2,9 +2,10 @@ package main import "path/filepath" +// updatePaths sets certificate and config file paths based on the current config func updatePaths() { - certDir = filepath.Join(baseDir, config.CertDir) - certPath = filepath.Join(certDir, config.CertFile) - keyPath = filepath.Join(certDir, config.KeyFile) - configPath = filepath.Join(baseDir, "config.yaml") + certDir = filepath.Join(baseDir, config.CertDir) // Certificate directory path + certPath = filepath.Join(certDir, config.CertFile) // Full certificate file path + keyPath = filepath.Join(certDir, config.KeyFile) // Full private key file path + configPath = filepath.Join(baseDir, "config.yaml") // Config file path }