added test
This commit is contained in:
@@ -0,0 +1,124 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func generateSelfSignedCert() error {
|
||||
if err := os.MkdirAll(certDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create certificate directory %s: %v", certDir, err)
|
||||
}
|
||||
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(365 * 24 * time.Hour)
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate serial number: %v", err)
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Proxy Self-Signed"},
|
||||
CommonName: "localhost",
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost", "*.example.com"},
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
certOut, err := os.Create(certPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s for writing: %v", certPath, err)
|
||||
}
|
||||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||
certOut.Close()
|
||||
return fmt.Errorf("failed to encode certificate: %v", err)
|
||||
}
|
||||
certOut.Close()
|
||||
|
||||
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
|
||||
}
|
||||
if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
|
||||
keyOut.Close()
|
||||
return fmt.Errorf("failed to encode private key: %v", err)
|
||||
}
|
||||
keyOut.Close()
|
||||
|
||||
refreshLogger.Printf("Generated self-signed certificate in %s", certDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadCertificate() error {
|
||||
certFile, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read certificate %s: %v", certPath, err)
|
||||
}
|
||||
keyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read key %s: %v", keyPath, err)
|
||||
}
|
||||
|
||||
newCert, err := tls.X509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse certificate: %v", err)
|
||||
}
|
||||
|
||||
configMux.Lock()
|
||||
cert = &newCert
|
||||
configMux.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func monitorCertificates() {
|
||||
var lastModTime time.Time
|
||||
for {
|
||||
certInfo, err := os.Stat(certPath)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Error checking certificate: %v", err)
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
keyInfo, err := os.Stat(keyPath)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Error checking key: %v", err)
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
if certInfo.ModTime() != lastModTime || keyInfo.ModTime() != lastModTime {
|
||||
if err := loadCertificate(); err != nil {
|
||||
errorLogger.Printf("Error reloading certificate: %v", err)
|
||||
} else {
|
||||
refreshLogger.Println("Certificate reloaded successfully")
|
||||
lastModTime = certInfo.ModTime()
|
||||
}
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ListenHTTP string `yaml:"listen_http"`
|
||||
ListenHTTPS string `yaml:"listen_https"`
|
||||
CertDir string `yaml:"cert_dir"`
|
||||
CertFile string `yaml:"cert_file"`
|
||||
KeyFile string `yaml:"key_file"`
|
||||
Routes map[string]string `yaml:"routes"`
|
||||
TrustTarget map[string]bool `yaml:"trust_target"`
|
||||
NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"`
|
||||
}
|
||||
|
||||
func loadConfig() (Config, error) {
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to read config %s: %v", configPath, err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
err = yaml.Unmarshal(data, &cfg)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func generateDefaultConfig() Config {
|
||||
return Config{
|
||||
ListenHTTP: ":80",
|
||||
ListenHTTPS: ":443",
|
||||
CertDir: "certificates",
|
||||
CertFile: "certificate.pem",
|
||||
KeyFile: "key.pem",
|
||||
Routes: map[string]string{
|
||||
"*": "https://127.0.0.1:3000",
|
||||
"main.example.com": "https://10.100.111.254:4444",
|
||||
},
|
||||
TrustTarget: map[string]bool{
|
||||
"*": true,
|
||||
"main.example.com": true,
|
||||
},
|
||||
NoHTTPSRedirect: map[string]bool{
|
||||
"*": false,
|
||||
"main.example.com": false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func saveConfig(cfg Config) error {
|
||||
if err := os.MkdirAll(baseDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create base directory %s: %v", baseDir, err)
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(configPath, data, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write config to %s: %v", configPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func monitorConfig() {
|
||||
var lastModTime time.Time
|
||||
for {
|
||||
configInfo, err := os.Stat(configPath)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Error checking config file: %v", err)
|
||||
time.Sleep(5 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
if configInfo.ModTime() != lastModTime {
|
||||
newConfig, err := loadConfig()
|
||||
if err != nil {
|
||||
errorLogger.Printf("Error reloading config: %v", err)
|
||||
} else {
|
||||
configMux.Lock()
|
||||
if newConfig.ListenHTTP != config.ListenHTTP {
|
||||
config.ListenHTTP = newConfig.ListenHTTP
|
||||
refreshLogger.Printf("Updated listen_http to %s", config.ListenHTTP)
|
||||
}
|
||||
if newConfig.ListenHTTPS != config.ListenHTTPS {
|
||||
config.ListenHTTPS = newConfig.ListenHTTPS
|
||||
refreshLogger.Printf("Updated listen_https to %s", config.ListenHTTPS)
|
||||
}
|
||||
if newConfig.CertDir != config.CertDir || newConfig.CertFile != config.CertFile || newConfig.KeyFile != config.KeyFile {
|
||||
config.CertDir = newConfig.CertDir
|
||||
config.CertFile = newConfig.CertFile
|
||||
config.KeyFile = newConfig.KeyFile
|
||||
updatePaths()
|
||||
if err := loadCertificate(); err != nil {
|
||||
errorLogger.Printf("Error reloading certificate after path change: %v", err)
|
||||
} else {
|
||||
refreshLogger.Println("Updated certificate paths and reloaded certificate")
|
||||
}
|
||||
}
|
||||
for k, v := range newConfig.Routes {
|
||||
if oldV, exists := config.Routes[k]; !exists || oldV != v {
|
||||
config.Routes[k] = v
|
||||
refreshLogger.Printf("Updated route %s to %s", k, v)
|
||||
}
|
||||
}
|
||||
for k := range config.Routes {
|
||||
if _, exists := newConfig.Routes[k]; !exists {
|
||||
delete(config.Routes, k)
|
||||
refreshLogger.Printf("Removed route %s", k)
|
||||
}
|
||||
}
|
||||
for k, v := range newConfig.TrustTarget {
|
||||
if oldV, exists := config.TrustTarget[k]; !exists || oldV != v {
|
||||
config.TrustTarget[k] = v
|
||||
refreshLogger.Printf("Updated trust_target %s to %v", k, v)
|
||||
}
|
||||
}
|
||||
for k := range config.TrustTarget {
|
||||
if _, exists := newConfig.TrustTarget[k]; !exists {
|
||||
delete(config.TrustTarget, k)
|
||||
refreshLogger.Printf("Removed trust_target %s", k)
|
||||
}
|
||||
}
|
||||
for k, v := range newConfig.NoHTTPSRedirect {
|
||||
if oldV, exists := config.NoHTTPSRedirect[k]; !exists || oldV != v {
|
||||
config.NoHTTPSRedirect[k] = v
|
||||
refreshLogger.Printf("Updated no_https_redirect %s to %v", k, v)
|
||||
}
|
||||
}
|
||||
for k := range config.NoHTTPSRedirect {
|
||||
if _, exists := newConfig.NoHTTPSRedirect[k]; !exists {
|
||||
delete(config.NoHTTPSRedirect, k)
|
||||
refreshLogger.Printf("Removed no_https_redirect %s", k)
|
||||
}
|
||||
}
|
||||
configMux.Unlock()
|
||||
refreshLogger.Println("Config reloaded successfully")
|
||||
lastModTime = configInfo.ModTime()
|
||||
}
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
}
|
||||
+197
@@ -0,0 +1,197 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
config Config
|
||||
configMux sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
baseDir string
|
||||
certDir string
|
||||
certPath string
|
||||
keyPath string
|
||||
configPath string
|
||||
|
||||
// Loggers
|
||||
errorLogger *log.Logger
|
||||
refreshLogger *log.Logger
|
||||
trafficLogger *log.Logger
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Get the absolute path of the running executable (or main.go in `go run`)
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get executable path: %v", err)
|
||||
}
|
||||
baseDir = filepath.Dir(exePath)
|
||||
/*
|
||||
if runtime.GOOS == "windows" && len(os.Args) > 0 && filepath.Ext(os.Args[0]) == ".go" {
|
||||
var err error
|
||||
baseDir, err = os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get working directory: %v", err)
|
||||
}
|
||||
} else {
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get executable path: %v", err)
|
||||
}
|
||||
baseDir = filepath.Dir(exePath)
|
||||
}
|
||||
*/
|
||||
// Setup logging
|
||||
setupLogging()
|
||||
}
|
||||
|
||||
func setupLogging() {
|
||||
logsDir := filepath.Join(baseDir, "logs")
|
||||
if err := os.MkdirAll(logsDir, 0755); err != nil {
|
||||
log.Fatalf("Failed to create logs directory: %v", err)
|
||||
}
|
||||
|
||||
// Error log
|
||||
errorFile, err := os.OpenFile(filepath.Join(logsDir, "errors.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open errors.log: %v", err)
|
||||
}
|
||||
errorLogger = log.New(io.MultiWriter(os.Stdout, errorFile), "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
|
||||
// Refresh log
|
||||
refreshFile, err := os.OpenFile(filepath.Join(logsDir, "refresh.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open refresh.log: %v", err)
|
||||
}
|
||||
refreshLogger = log.New(io.MultiWriter(os.Stdout, refreshFile), "REFRESH: ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
|
||||
// Traffic log (daily rotation)
|
||||
go manageTrafficLogs(logsDir)
|
||||
}
|
||||
|
||||
func manageTrafficLogs(logsDir string) {
|
||||
for {
|
||||
dateStr := time.Now().Format("2006-01-02")
|
||||
trafficFile, err := os.OpenFile(filepath.Join(logsDir, "traffic-"+dateStr+".log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Failed to open traffic log: %v", err)
|
||||
time.Sleep(1 * time.Minute)
|
||||
continue
|
||||
}
|
||||
trafficLogger = log.New(io.MultiWriter(os.Stdout, trafficFile), "TRAFFIC: ", log.Ldate|log.Ltime)
|
||||
|
||||
// Cleanup old logs
|
||||
cleanupOldLogs(logsDir)
|
||||
|
||||
// Wait until next day
|
||||
nextDay := time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour)
|
||||
time.Sleep(time.Until(nextDay))
|
||||
trafficFile.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func cleanupOldLogs(logsDir string) {
|
||||
files, err := os.ReadDir(logsDir)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Failed to read logs directory: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
cutoff := time.Now().AddDate(0, 0, -7) // 7 days ago
|
||||
for _, file := range files {
|
||||
if strings.HasPrefix(file.Name(), "traffic-") && file.Name() != "traffic-"+time.Now().Format("2006-01-02")+".log" {
|
||||
info, err := file.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.ModTime().Before(cutoff) {
|
||||
if err := os.Remove(filepath.Join(logsDir, file.Name())); err != nil {
|
||||
errorLogger.Printf("Failed to remove old log %s: %v", file.Name(), err)
|
||||
} else {
|
||||
refreshLogger.Printf("Removed old traffic log: %s", file.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
configPath = filepath.Join(baseDir, "config.yaml")
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
config = generateDefaultConfig()
|
||||
if err := saveConfig(config); err != nil {
|
||||
errorLogger.Fatalf("Failed to save default config: %v", err)
|
||||
}
|
||||
refreshLogger.Println("Generated default config file")
|
||||
} else {
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
errorLogger.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
config = cfg
|
||||
}
|
||||
|
||||
updatePaths()
|
||||
|
||||
_, certErr := os.Stat(certPath)
|
||||
_, keyErr := os.Stat(keyPath)
|
||||
if os.IsNotExist(certErr) || os.IsNotExist(keyErr) {
|
||||
if err := generateSelfSignedCert(); err != nil {
|
||||
errorLogger.Fatalf("Failed to generate self-signed certificate: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := loadCertificate(); err != nil {
|
||||
errorLogger.Fatalf("Failed to load certificate: %v", err)
|
||||
}
|
||||
|
||||
go monitorCertificates()
|
||||
go monitorConfig()
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: config.ListenHTTP,
|
||||
Handler: http.HandlerFunc(handler),
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
configMux.RLock()
|
||||
defer configMux.RUnlock()
|
||||
return cert, nil
|
||||
},
|
||||
}
|
||||
httpsServer := &http.Server{
|
||||
Addr: config.ListenHTTPS,
|
||||
Handler: http.HandlerFunc(handler),
|
||||
TLSConfig: tlsConfig,
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
||||
trafficLogger.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
errorLogger.Fatalf("HTTP server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
||||
trafficLogger.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
||||
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
||||
errorLogger.Fatalf("HTTPS server error: %v", err)
|
||||
}
|
||||
}
|
||||
+332
@@ -0,0 +1,332 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
var (
|
||||
transportPool = &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext,
|
||||
ResponseHeaderTimeout: 60 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
cache = make(map[string]cachedResponse)
|
||||
defaultCacheTTL = 5 * time.Minute
|
||||
cacheMutex sync.RWMutex
|
||||
|
||||
// Rate limiter per client IP
|
||||
rateLimiters = make(map[string]*rate.Limiter)
|
||||
rateMutex sync.RWMutex
|
||||
rateLimit = rate.Limit(10) // 10 requests per second per client
|
||||
rateBurst = 20 // Allow burst of 20 requests
|
||||
)
|
||||
|
||||
type cachedResponse struct {
|
||||
body []byte
|
||||
headers http.Header
|
||||
statusCode int
|
||||
cachedAt time.Time
|
||||
cacheDuration time.Duration
|
||||
etag string
|
||||
}
|
||||
|
||||
func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
log.Printf("Error parsing target URL %s: %v", target, err)
|
||||
errorLogger.Printf("Error parsing target URL %s: %v", target, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
director := func(req *http.Request) {
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
// Preserve client's Host header for session continuity
|
||||
if req.Header.Get("Host") != "" {
|
||||
req.Host = req.Header.Get("Host")
|
||||
} else {
|
||||
req.Host = targetURL.Host
|
||||
}
|
||||
|
||||
// Preserve all request headers, including cookies
|
||||
for k, v := range req.Header {
|
||||
req.Header[k] = v
|
||||
}
|
||||
|
||||
// Preserve query parameters and path
|
||||
req.URL.RawQuery = req.URL.RawQuery
|
||||
if targetURL.Path != "" {
|
||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, "/")
|
||||
req.URL.Path = singleJoin(targetURL.Path, req.URL.Path)
|
||||
}
|
||||
|
||||
// WebSocket support: pass upgrade headers
|
||||
if strings.ToLower(req.Header.Get("Upgrade")) == "websocket" {
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
}
|
||||
|
||||
trafficLogger.Printf("Request: %s %s -> %s [Host: %s]", req.Method, req.URL.String(), target, req.Host)
|
||||
}
|
||||
|
||||
transport := transportPool
|
||||
if !skipVerify {
|
||||
transport = &http.Transport{
|
||||
MaxIdleConns: transportPool.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: transportPool.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: transportPool.IdleConnTimeout,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
||||
}
|
||||
}
|
||||
|
||||
return &httputil.ReverseProxy{
|
||||
Director: director,
|
||||
Transport: transport,
|
||||
ModifyResponse: func(resp *http.Response) error {
|
||||
// Preserve all response headers, including Set-Cookie
|
||||
for k, v := range resp.Header {
|
||||
resp.Header[k] = v
|
||||
}
|
||||
|
||||
// Handle WebSocket upgrade
|
||||
if strings.ToLower(resp.Header.Get("Upgrade")) == "websocket" {
|
||||
return nil // No further modification needed for WebSocket
|
||||
}
|
||||
|
||||
// Compression if client supports it and response isn’t already compressed
|
||||
if resp.Header.Get("Content-Encoding") == "" && strings.Contains(resp.Request.Header.Get("Accept-Encoding"), "gzip") {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Error reading response body for compression: %v", err)
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
var buf bytes.Buffer
|
||||
gw := gzip.NewWriter(&buf)
|
||||
if _, err := gw.Write(body); err != nil {
|
||||
errorLogger.Printf("Error compressing response: %v", err)
|
||||
return err
|
||||
}
|
||||
gw.Close()
|
||||
|
||||
resp.Body = io.NopCloser(&buf)
|
||||
resp.Header.Set("Content-Encoding", "gzip")
|
||||
resp.Header.Del("Content-Length") // Length changes after compression
|
||||
trafficLogger.Printf("Compressed response for %s", resp.Request.URL.String())
|
||||
}
|
||||
|
||||
// Cache static content with ETag support
|
||||
if shouldCache(resp) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Error reading response body for caching: %v", err)
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
etag := resp.Header.Get("ETag")
|
||||
if etag == "" {
|
||||
etag = generateETag(body) // Simple ETag generation if not provided
|
||||
}
|
||||
|
||||
cacheDuration := parseCacheControl(resp.Header.Get("Cache-Control"))
|
||||
if cacheDuration == 0 {
|
||||
cacheDuration = defaultCacheTTL
|
||||
}
|
||||
|
||||
cacheMutex.Lock()
|
||||
cache[resp.Request.URL.String()] = cachedResponse{
|
||||
body: body,
|
||||
headers: resp.Header.Clone(),
|
||||
statusCode: resp.StatusCode,
|
||||
cachedAt: time.Now(),
|
||||
cacheDuration: cacheDuration,
|
||||
etag: etag,
|
||||
}
|
||||
cacheMutex.Unlock()
|
||||
trafficLogger.Printf("Cached response for %s [ETag: %s]", resp.Request.URL.String(), etag)
|
||||
}
|
||||
|
||||
trafficLogger.Printf("Response: %s %d from %s", resp.Status, resp.StatusCode, target)
|
||||
return nil
|
||||
},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
log.Printf("Proxy error for %s: %v", r.Host, err)
|
||||
errorLogger.Printf("Proxy error for %s: %v", r.Host, err)
|
||||
http.Error(w, "Proxy error", http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func shouldCache(resp *http.Response) bool {
|
||||
if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK {
|
||||
return false
|
||||
}
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
return strings.HasPrefix(contentType, "text/") || strings.HasPrefix(contentType, "image/") ||
|
||||
strings.HasPrefix(contentType, "application/javascript") || strings.HasPrefix(contentType, "application/json")
|
||||
}
|
||||
|
||||
func singleJoin(prefix, suffix string) string {
|
||||
prefix = strings.TrimSuffix(prefix, "/")
|
||||
suffix = strings.TrimPrefix(suffix, "/")
|
||||
return prefix + "/" + suffix
|
||||
}
|
||||
|
||||
func generateETag(body []byte) string {
|
||||
return fmt.Sprintf(`"%x"`, md5.Sum(body)) // Simple ETag based on MD5 hash
|
||||
}
|
||||
|
||||
func parseCacheControl(header string) time.Duration {
|
||||
if header == "" {
|
||||
return 0
|
||||
}
|
||||
parts := strings.Split(header, ",")
|
||||
for _, part := range parts {
|
||||
if strings.Contains(part, "max-age=") {
|
||||
ageStr := strings.TrimPrefix(part, "max-age=")
|
||||
ageStr = strings.TrimSpace(ageStr)
|
||||
if age, err := strconv.Atoi(ageStr); err == nil {
|
||||
return time.Duration(age) * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getLimiter(ip string) *rate.Limiter {
|
||||
rateMutex.Lock()
|
||||
defer rateMutex.Unlock()
|
||||
|
||||
if limiter, exists := rateLimiters[ip]; exists {
|
||||
return limiter
|
||||
}
|
||||
limiter := rate.NewLimiter(rateLimit, rateBurst)
|
||||
rateLimiters[ip] = limiter
|
||||
return limiter
|
||||
}
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
configMux.RLock()
|
||||
defer configMux.RUnlock()
|
||||
|
||||
// Rate limiting based on client IP
|
||||
clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")] // Strip port
|
||||
limiter := getLimiter(clientIP)
|
||||
if !limiter.Allow() {
|
||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||
trafficLogger.Printf("Rate limited client %s", clientIP)
|
||||
return
|
||||
}
|
||||
|
||||
target, exists := config.Routes[r.Host]
|
||||
skipVerify := config.TrustTarget[r.Host]
|
||||
noHTTPSRedirect := config.NoHTTPSRedirect[r.Host]
|
||||
|
||||
if !exists {
|
||||
if target, exists = config.Routes["*"]; !exists {
|
||||
http.Error(w, "Host not configured", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
skipVerify = config.TrustTarget["*"]
|
||||
noHTTPSRedirect = config.NoHTTPSRedirect["*"]
|
||||
}
|
||||
|
||||
isHTTPS := target[:5] == "https"
|
||||
isHTTPReq := r.TLS == nil
|
||||
|
||||
if isHTTPReq && isHTTPS && !noHTTPSRedirect {
|
||||
redirectURL := "https://" + r.Host + r.RequestURI
|
||||
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
|
||||
cacheKey := r.URL.String()
|
||||
cacheMutex.RLock()
|
||||
if cached, ok := cache[cacheKey]; ok && time.Since(cached.cachedAt) < cached.cacheDuration {
|
||||
if etag := r.Header.Get("If-None-Match"); etag != "" && etag == cached.etag {
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
trafficLogger.Printf("Served 304 Not Modified from cache: %s [ETag: %s]", cacheKey, cached.etag)
|
||||
} else {
|
||||
for k, v := range cached.headers {
|
||||
w.Header()[k] = v
|
||||
}
|
||||
w.WriteHeader(cached.statusCode)
|
||||
w.Write(cached.body)
|
||||
trafficLogger.Printf("Served from cache: %s [ETag: %s]", cacheKey, cached.etag)
|
||||
}
|
||||
cacheMutex.RUnlock()
|
||||
return
|
||||
}
|
||||
cacheMutex.RUnlock()
|
||||
|
||||
proxy := getReverseProxy(target, skipVerify)
|
||||
if proxy == nil {
|
||||
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Error parsing target URL %s: %v", target, err)
|
||||
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// WebSocket upgrade handling
|
||||
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "WebSocket upgrade not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
conn, _, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
errorLogger.Printf("Failed to hijack connection for WebSocket: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
targetConn, err := transportPool.Dial("tcp", targetURL.Host)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Failed to dial target for WebSocket: %v", err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
// Forward request to target
|
||||
if err := r.Write(targetConn); err != nil {
|
||||
errorLogger.Printf("Failed to forward WebSocket request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Pipe connections
|
||||
go io.Copy(conn, targetConn)
|
||||
io.Copy(targetConn, conn)
|
||||
trafficLogger.Printf("WebSocket connection established: %s -> %s", r.Host, target)
|
||||
return
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package main
|
||||
|
||||
import "path/filepath"
|
||||
|
||||
func updatePaths() {
|
||||
certDir = filepath.Join(baseDir, config.CertDir)
|
||||
certPath = filepath.Join(certDir, config.CertFile)
|
||||
keyPath = filepath.Join(certDir, config.KeyFile)
|
||||
configPath = filepath.Join(baseDir, "config.yaml")
|
||||
}
|
||||
+16
-3
@@ -13,23 +13,30 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// generateSelfSignedCert creates a self-signed TLS certificate if one doesn’t exist
|
||||
func generateSelfSignedCert() error {
|
||||
// Ensure certificate directory exists
|
||||
if err := os.MkdirAll(certDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create certificate directory %s: %v", certDir, err)
|
||||
}
|
||||
|
||||
// Generate a 2048-bit RSA private key
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
// Set certificate validity period (1 year)
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(365 * 24 * time.Hour)
|
||||
|
||||
// Generate a random serial number
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate serial number: %v", err)
|
||||
}
|
||||
|
||||
// Define certificate template with basic attributes
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
@@ -41,14 +48,16 @@ func generateSelfSignedCert() error {
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost", "*.example.com"},
|
||||
DNSNames: []string{"localhost", "*.example.com"}, // Supported domains
|
||||
}
|
||||
|
||||
// Create self-signed certificate
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
// Write certificate to file
|
||||
certOut, err := os.Create(certPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s for writing: %v", certPath, err)
|
||||
@@ -59,7 +68,8 @@ func generateSelfSignedCert() error {
|
||||
}
|
||||
certOut.Close()
|
||||
|
||||
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_WRONLY, 0600)
|
||||
// Write private key to file
|
||||
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
|
||||
}
|
||||
@@ -73,6 +83,7 @@ func generateSelfSignedCert() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadCertificate reads and parses the certificate and key files into memory
|
||||
func loadCertificate() error {
|
||||
certFile, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
@@ -94,6 +105,7 @@ func loadCertificate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitorCertificates watches for changes to certificate/key files and reloads them
|
||||
func monitorCertificates() {
|
||||
var lastModTime time.Time
|
||||
for {
|
||||
@@ -111,6 +123,7 @@ func monitorCertificates() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Reload certificate if either file has changed
|
||||
if certInfo.ModTime() != lastModTime || keyInfo.ModTime() != lastModTime {
|
||||
if err := loadCertificate(); err != nil {
|
||||
errorLogger.Printf("Error reloading certificate: %v", err)
|
||||
@@ -119,6 +132,6 @@ func monitorCertificates() {
|
||||
lastModTime = certInfo.ModTime()
|
||||
}
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
time.Sleep(5 * time.Second) // Poll every 5 seconds
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,17 +8,19 @@ import (
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// Config defines the structure of the proxy configuration loaded from config.yaml
|
||||
type Config struct {
|
||||
ListenHTTP string `yaml:"listen_http"`
|
||||
ListenHTTPS string `yaml:"listen_https"`
|
||||
CertDir string `yaml:"cert_dir"`
|
||||
CertFile string `yaml:"cert_file"`
|
||||
KeyFile string `yaml:"key_file"`
|
||||
Routes map[string]string `yaml:"routes"`
|
||||
TrustTarget map[string]bool `yaml:"trust_target"`
|
||||
NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"`
|
||||
ListenHTTP string `yaml:"listen_http"` // Port for HTTP server (e.g., ":80")
|
||||
ListenHTTPS string `yaml:"listen_https"` // Port for HTTPS server (e.g., ":443")
|
||||
CertDir string `yaml:"cert_dir"` // Directory for certificate files
|
||||
CertFile string `yaml:"cert_file"` // Certificate filename
|
||||
KeyFile string `yaml:"key_file"` // Private key filename
|
||||
Routes map[string]string `yaml:"routes"` // Mapping of hostnames to target URLs
|
||||
TrustTarget map[string]bool `yaml:"trust_target"` // Whether to skip TLS verification for targets
|
||||
NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"` // Whether to skip HTTP->HTTPS redirect
|
||||
}
|
||||
|
||||
// loadConfig reads and parses the config.yaml file
|
||||
func loadConfig() (Config, error) {
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
@@ -26,35 +28,36 @@ func loadConfig() (Config, error) {
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
err = yaml.Unmarshal(data, &cfg)
|
||||
if err != nil {
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return Config{}, fmt.Errorf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// generateDefaultConfig creates a default configuration if config.yaml doesn’t exist
|
||||
func generateDefaultConfig() Config {
|
||||
return Config{
|
||||
ListenHTTP: ":80",
|
||||
ListenHTTPS: ":443",
|
||||
CertDir: "certificates",
|
||||
CertDir: "./certificate",
|
||||
CertFile: "certificate.pem",
|
||||
KeyFile: "key.pem",
|
||||
Routes: map[string]string{
|
||||
"*": "https://127.0.0.1:3000",
|
||||
"main.example.com": "https://10.100.111.254:4444",
|
||||
"*": "https://127.0.0.1:3000", // Wildcard route
|
||||
"main.example.com": "https://10.100.111.254:4444", // Specific route
|
||||
},
|
||||
TrustTarget: map[string]bool{
|
||||
"*": true,
|
||||
"*": true, // Skip TLS verification by default
|
||||
"main.example.com": true,
|
||||
},
|
||||
NoHTTPSRedirect: map[string]bool{
|
||||
"*": false,
|
||||
"*": false, // Redirect HTTP to HTTPS by default
|
||||
"main.example.com": false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// saveConfig writes the configuration to config.yaml
|
||||
func saveConfig(cfg Config) error {
|
||||
if err := os.MkdirAll(baseDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create base directory %s: %v", baseDir, err)
|
||||
@@ -65,13 +68,13 @@ func saveConfig(cfg Config) error {
|
||||
return fmt.Errorf("failed to marshal config: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(configPath, data, 0644)
|
||||
if err != nil {
|
||||
if err := os.WriteFile(configPath, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write config to %s: %v", configPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitorConfig watches config.yaml for changes and updates the in-memory config
|
||||
func monitorConfig() {
|
||||
var lastModTime time.Time
|
||||
for {
|
||||
@@ -88,6 +91,7 @@ func monitorConfig() {
|
||||
errorLogger.Printf("Error reloading config: %v", err)
|
||||
} else {
|
||||
configMux.Lock()
|
||||
// Update individual fields only if they’ve changed
|
||||
if newConfig.ListenHTTP != config.ListenHTTP {
|
||||
config.ListenHTTP = newConfig.ListenHTTP
|
||||
refreshLogger.Printf("Updated listen_http to %s", config.ListenHTTP)
|
||||
@@ -107,6 +111,7 @@ func monitorConfig() {
|
||||
refreshLogger.Println("Updated certificate paths and reloaded certificate")
|
||||
}
|
||||
}
|
||||
// Update routes
|
||||
for k, v := range newConfig.Routes {
|
||||
if oldV, exists := config.Routes[k]; !exists || oldV != v {
|
||||
config.Routes[k] = v
|
||||
@@ -119,6 +124,7 @@ func monitorConfig() {
|
||||
refreshLogger.Printf("Removed route %s", k)
|
||||
}
|
||||
}
|
||||
// Update trust_target
|
||||
for k, v := range newConfig.TrustTarget {
|
||||
if oldV, exists := config.TrustTarget[k]; !exists || oldV != v {
|
||||
config.TrustTarget[k] = v
|
||||
@@ -131,6 +137,7 @@ func monitorConfig() {
|
||||
refreshLogger.Printf("Removed trust_target %s", k)
|
||||
}
|
||||
}
|
||||
// Update no_https_redirect
|
||||
for k, v := range newConfig.NoHTTPSRedirect {
|
||||
if oldV, exists := config.NoHTTPSRedirect[k]; !exists || oldV != v {
|
||||
config.NoHTTPSRedirect[k] = v
|
||||
@@ -148,6 +155,6 @@ func monitorConfig() {
|
||||
lastModTime = configInfo.ModTime()
|
||||
}
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
time.Sleep(5 * time.Second) // Poll every 5 seconds
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,77 +7,78 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Shared global variables used across all files
|
||||
var (
|
||||
config Config
|
||||
configMux sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
baseDir string
|
||||
certDir string
|
||||
certPath string
|
||||
keyPath string
|
||||
configPath string
|
||||
config Config // Holds the proxy configuration loaded from config.yaml
|
||||
configMux sync.RWMutex // Mutex for thread-safe access to config
|
||||
cert *tls.Certificate // TLS certificate for HTTPS server
|
||||
baseDir string // Base directory (working dir for go run, executable dir for binary)
|
||||
certDir string // Directory for certificates
|
||||
certPath string // Path to certificate file
|
||||
keyPath string // Path to private key file
|
||||
configPath string // Path to config.yaml
|
||||
|
||||
// Loggers
|
||||
errorLogger *log.Logger
|
||||
refreshLogger *log.Logger
|
||||
trafficLogger *log.Logger
|
||||
// Loggers for different types of events
|
||||
errorLogger *log.Logger // Logs errors to errors.log and console
|
||||
refreshLogger *log.Logger // Logs config/certificate refreshes to refresh.log and console
|
||||
trafficLogger *log.Logger // Logs traffic to traffic-YYYY-MM-DD.log and console
|
||||
)
|
||||
|
||||
// init sets up the base directory and initializes logging
|
||||
func init() {
|
||||
// Get the absolute path of the running executable (or main.go in `go run`)
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get executable path: %v", err)
|
||||
}
|
||||
baseDir = filepath.Dir(exePath)
|
||||
/*
|
||||
// Determine base directory based on execution context
|
||||
if runtime.GOOS == "windows" && len(os.Args) > 0 && filepath.Ext(os.Args[0]) == ".go" {
|
||||
// For "go run", use current working directory
|
||||
var err error
|
||||
baseDir, err = os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get working directory: %v", err)
|
||||
}
|
||||
} else {
|
||||
// For compiled binary, use executable's directory
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get executable path: %v", err)
|
||||
}
|
||||
baseDir = filepath.Dir(exePath)
|
||||
}
|
||||
*/
|
||||
// Setup logging
|
||||
|
||||
// Initialize logging to files and console
|
||||
setupLogging()
|
||||
}
|
||||
|
||||
// setupLogging configures loggers for errors, refreshes, and traffic
|
||||
func setupLogging() {
|
||||
logsDir := filepath.Join(baseDir, "logs")
|
||||
if err := os.MkdirAll(logsDir, 0755); err != nil {
|
||||
log.Fatalf("Failed to create logs directory: %v", err)
|
||||
}
|
||||
|
||||
// Error log
|
||||
// Error log: writes to errors.log and stdout
|
||||
errorFile, err := os.OpenFile(filepath.Join(logsDir, "errors.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open errors.log: %v", err)
|
||||
}
|
||||
errorLogger = log.New(io.MultiWriter(os.Stdout, errorFile), "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
|
||||
// Refresh log
|
||||
// Refresh log: writes to refresh.log and stdout
|
||||
refreshFile, err := os.OpenFile(filepath.Join(logsDir, "refresh.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open refresh.log: %v", err)
|
||||
}
|
||||
refreshLogger = log.New(io.MultiWriter(os.Stdout, refreshFile), "REFRESH: ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
|
||||
// Traffic log (daily rotation)
|
||||
// Traffic log: managed in a goroutine for daily rotation
|
||||
go manageTrafficLogs(logsDir)
|
||||
}
|
||||
|
||||
// manageTrafficLogs handles daily rotation of traffic logs with 7-day retention
|
||||
func manageTrafficLogs(logsDir string) {
|
||||
for {
|
||||
dateStr := time.Now().Format("2006-01-02")
|
||||
@@ -89,7 +90,7 @@ func manageTrafficLogs(logsDir string) {
|
||||
}
|
||||
trafficLogger = log.New(io.MultiWriter(os.Stdout, trafficFile), "TRAFFIC: ", log.Ldate|log.Ltime)
|
||||
|
||||
// Cleanup old logs
|
||||
// Cleanup logs older than 7 days
|
||||
cleanupOldLogs(logsDir)
|
||||
|
||||
// Wait until next day
|
||||
@@ -99,6 +100,7 @@ func manageTrafficLogs(logsDir string) {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupOldLogs removes traffic logs older than 7 days
|
||||
func cleanupOldLogs(logsDir string) {
|
||||
files, err := os.ReadDir(logsDir)
|
||||
if err != nil {
|
||||
@@ -106,7 +108,7 @@ func cleanupOldLogs(logsDir string) {
|
||||
return
|
||||
}
|
||||
|
||||
cutoff := time.Now().AddDate(0, 0, -7) // 7 days ago
|
||||
cutoff := time.Now().AddDate(0, 0, -7)
|
||||
for _, file := range files {
|
||||
if strings.HasPrefix(file.Name(), "traffic-") && file.Name() != "traffic-"+time.Now().Format("2006-01-02")+".log" {
|
||||
info, err := file.Info()
|
||||
@@ -124,7 +126,9 @@ func cleanupOldLogs(logsDir string) {
|
||||
}
|
||||
}
|
||||
|
||||
// main is the entry point, setting up and running the HTTP/HTTPS servers
|
||||
func main() {
|
||||
// Set config file path and load or generate initial config
|
||||
configPath = filepath.Join(baseDir, "config.yaml")
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
config = generateDefaultConfig()
|
||||
@@ -140,8 +144,10 @@ func main() {
|
||||
config = cfg
|
||||
}
|
||||
|
||||
// Update certificate and config paths based on loaded config
|
||||
updatePaths()
|
||||
|
||||
// Generate or load TLS certificate
|
||||
_, certErr := os.Stat(certPath)
|
||||
_, keyErr := os.Stat(keyPath)
|
||||
if os.IsNotExist(certErr) || os.IsNotExist(keyErr) {
|
||||
@@ -149,22 +155,24 @@ func main() {
|
||||
errorLogger.Fatalf("Failed to generate self-signed certificate: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := loadCertificate(); err != nil {
|
||||
errorLogger.Fatalf("Failed to load certificate: %v", err)
|
||||
}
|
||||
|
||||
// Start background monitoring for config and certificate changes
|
||||
go monitorCertificates()
|
||||
go monitorConfig()
|
||||
|
||||
// Configure HTTP server with timeouts for robustness
|
||||
httpServer := &http.Server{
|
||||
Addr: config.ListenHTTP,
|
||||
Handler: http.HandlerFunc(handler),
|
||||
MaxHeaderBytes: 1 << 20,
|
||||
MaxHeaderBytes: 1 << 20, // 1 MB max header size
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
// Configure HTTPS server with TLS and certificate fetching
|
||||
tlsConfig := &tls.Config{
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
configMux.RLock()
|
||||
@@ -181,16 +189,16 @@ func main() {
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
// Start HTTP server in a goroutine
|
||||
go func() {
|
||||
log.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
||||
trafficLogger.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
||||
refreshLogger.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
errorLogger.Fatalf("HTTP server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
||||
trafficLogger.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
||||
// Start HTTPS server in the main goroutine
|
||||
refreshLogger.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
||||
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
||||
errorLogger.Fatalf("HTTPS server error: %v", err)
|
||||
}
|
||||
|
||||
@@ -21,26 +21,29 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
// transportPool configures HTTP transport with timeouts matching Nginx defaults
|
||||
transportPool = &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext,
|
||||
ResponseHeaderTimeout: 60 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
MaxIdleConns: 100, // Max idle connections
|
||||
MaxIdleConnsPerHost: 10, // Max idle per host
|
||||
IdleConnTimeout: 90 * time.Second, // Idle connection timeout
|
||||
DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext, // Dial timeout
|
||||
ResponseHeaderTimeout: 60 * time.Second, // Response header timeout (Nginx-like)
|
||||
TLSHandshakeTimeout: 10 * time.Second, // TLS handshake timeout
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // Default skip TLS verification
|
||||
}
|
||||
cache = make(map[string]cachedResponse)
|
||||
defaultCacheTTL = 5 * time.Minute
|
||||
cacheMutex sync.RWMutex
|
||||
cache = make(map[string]cachedResponse) // Cache for static responses
|
||||
cacheMutex sync.RWMutex // Mutex for cache access
|
||||
|
||||
// Rate limiter per client IP
|
||||
// Rate limiting per client IP
|
||||
rateLimiters = make(map[string]*rate.Limiter)
|
||||
rateMutex sync.RWMutex
|
||||
rateLimit = rate.Limit(10) // 10 requests per second per client
|
||||
rateBurst = 20 // Allow burst of 20 requests
|
||||
rateLimit = rate.Limit(10) // 10 req/s per client
|
||||
rateBurst = 20 // Burst allowance
|
||||
|
||||
defaultCacheTTL = 5 * time.Minute // Default TTL for cached responses
|
||||
)
|
||||
|
||||
// cachedResponse stores cached response details
|
||||
type cachedResponse struct {
|
||||
body []byte
|
||||
headers http.Header
|
||||
@@ -50,7 +53,8 @@ type cachedResponse struct {
|
||||
etag string
|
||||
}
|
||||
|
||||
func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
// getReverseProxy creates a reverse proxy for a target URL
|
||||
func getReverseProxy(target string, skipVerify bool, originalReq *http.Request) *httputil.ReverseProxy {
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
log.Printf("Error parsing target URL %s: %v", target, err)
|
||||
@@ -58,42 +62,48 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
return nil
|
||||
}
|
||||
|
||||
// director modifies the request to forward it to the target
|
||||
director := func(req *http.Request) {
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
// Preserve client's Host header for session continuity
|
||||
if req.Header.Get("Host") != "" {
|
||||
req.Host = req.Header.Get("Host")
|
||||
// Preserve the original client’s Host header for session continuity
|
||||
if originalReq.Header.Get("Host") != "" {
|
||||
req.Host = originalReq.Header.Get("Host")
|
||||
} else {
|
||||
req.Host = targetURL.Host
|
||||
}
|
||||
|
||||
// Preserve all request headers, including cookies
|
||||
for k, v := range req.Header {
|
||||
// Copy all request headers to ensure cookies and session data are preserved
|
||||
for k, v := range originalReq.Header {
|
||||
req.Header[k] = v
|
||||
}
|
||||
|
||||
// Preserve query parameters and path
|
||||
req.URL.RawQuery = req.URL.RawQuery
|
||||
// Preserve query parameters and full path
|
||||
req.URL.RawQuery = originalReq.URL.RawQuery
|
||||
if targetURL.Path != "" {
|
||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, "/")
|
||||
req.URL.Path = strings.TrimPrefix(originalReq.URL.Path, "/")
|
||||
req.URL.Path = singleJoin(targetURL.Path, req.URL.Path)
|
||||
}
|
||||
|
||||
// WebSocket support: pass upgrade headers
|
||||
// Ensure WebSocket headers are passed correctly
|
||||
if strings.ToLower(req.Header.Get("Upgrade")) == "websocket" {
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
}
|
||||
|
||||
trafficLogger.Printf("Request: %s %s -> %s [Host: %s]", req.Method, req.URL.String(), target, req.Host)
|
||||
// Log request details for debugging
|
||||
trafficLogger.Printf("Request: %s %s -> %s [Host: %s] [Headers: %v]", req.Method, req.URL.String(), target, req.Host, req.Header)
|
||||
}
|
||||
|
||||
transport := transportPool
|
||||
if !skipVerify {
|
||||
// Use a new transport with TLS verification if skipVerify is false
|
||||
transport = &http.Transport{
|
||||
MaxIdleConns: transportPool.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: transportPool.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: transportPool.IdleConnTimeout,
|
||||
DialContext: transportPool.DialContext,
|
||||
ResponseHeaderTimeout: transportPool.ResponseHeaderTimeout,
|
||||
TLSHandshakeTimeout: transportPool.TLSHandshakeTimeout,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
||||
}
|
||||
}
|
||||
@@ -101,18 +111,19 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
return &httputil.ReverseProxy{
|
||||
Director: director,
|
||||
Transport: transport,
|
||||
// ModifyResponse processes responses, avoiding corruption
|
||||
ModifyResponse: func(resp *http.Response) error {
|
||||
// Preserve all response headers, including Set-Cookie
|
||||
// Preserve all response headers, including Set-Cookie for sessions
|
||||
for k, v := range resp.Header {
|
||||
resp.Header[k] = v
|
||||
}
|
||||
|
||||
// Handle WebSocket upgrade
|
||||
// Skip modification for WebSocket responses
|
||||
if strings.ToLower(resp.Header.Get("Upgrade")) == "websocket" {
|
||||
return nil // No further modification needed for WebSocket
|
||||
return nil
|
||||
}
|
||||
|
||||
// Compression if client supports it and response isn’t already compressed
|
||||
// Apply compression only if not already compressed and client supports it
|
||||
if resp.Header.Get("Content-Encoding") == "" && strings.Contains(resp.Request.Header.Get("Accept-Encoding"), "gzip") {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@@ -131,11 +142,11 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
|
||||
resp.Body = io.NopCloser(&buf)
|
||||
resp.Header.Set("Content-Encoding", "gzip")
|
||||
resp.Header.Del("Content-Length") // Length changes after compression
|
||||
resp.Header.Del("Content-Length")
|
||||
trafficLogger.Printf("Compressed response for %s", resp.Request.URL.String())
|
||||
}
|
||||
|
||||
// Cache static content with ETag support
|
||||
// Cache static content if applicable
|
||||
if shouldCache(resp) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
@@ -147,7 +158,7 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
|
||||
etag := resp.Header.Get("ETag")
|
||||
if etag == "" {
|
||||
etag = generateETag(body) // Simple ETag generation if not provided
|
||||
etag = generateETag(body)
|
||||
}
|
||||
|
||||
cacheDuration := parseCacheControl(resp.Header.Get("Cache-Control"))
|
||||
@@ -165,13 +176,18 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
etag: etag,
|
||||
}
|
||||
cacheMutex.Unlock()
|
||||
resp.Header.Set("ETag", etag)
|
||||
trafficLogger.Printf("Cached response for %s [ETag: %s]", resp.Request.URL.String(), etag)
|
||||
}
|
||||
|
||||
trafficLogger.Printf("Response: %s %d from %s", resp.Status, resp.StatusCode, target)
|
||||
trafficLogger.Printf("Response: %s %d from %s [Headers: %v]", resp.Status, resp.StatusCode, target, resp.Header)
|
||||
return nil
|
||||
},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if err.Error() == "context canceled" {
|
||||
trafficLogger.Printf("Request canceled by client for %s: %v", r.Host, err)
|
||||
return
|
||||
}
|
||||
log.Printf("Proxy error for %s: %v", r.Host, err)
|
||||
errorLogger.Printf("Proxy error for %s: %v", r.Host, err)
|
||||
http.Error(w, "Proxy error", http.StatusBadGateway)
|
||||
@@ -179,6 +195,7 @@ func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||
}
|
||||
}
|
||||
|
||||
// shouldCache checks if a response should be cached based on method and content type
|
||||
func shouldCache(resp *http.Response) bool {
|
||||
if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK {
|
||||
return false
|
||||
@@ -188,25 +205,27 @@ func shouldCache(resp *http.Response) bool {
|
||||
strings.HasPrefix(contentType, "application/javascript") || strings.HasPrefix(contentType, "application/json")
|
||||
}
|
||||
|
||||
// singleJoin combines path segments with a single slash
|
||||
func singleJoin(prefix, suffix string) string {
|
||||
prefix = strings.TrimSuffix(prefix, "/")
|
||||
suffix = strings.TrimPrefix(suffix, "/")
|
||||
return prefix + "/" + suffix
|
||||
}
|
||||
|
||||
// generateETag creates an ETag from the response body using MD5
|
||||
func generateETag(body []byte) string {
|
||||
return fmt.Sprintf(`"%x"`, md5.Sum(body)) // Simple ETag based on MD5 hash
|
||||
return fmt.Sprintf(`"%x"`, md5.Sum(body))
|
||||
}
|
||||
|
||||
// parseCacheControl extracts max-age from the Cache-Control header for caching duration
|
||||
func parseCacheControl(header string) time.Duration {
|
||||
if header == "" {
|
||||
return 0
|
||||
}
|
||||
parts := strings.Split(header, ",")
|
||||
for _, part := range parts {
|
||||
if strings.Contains(part, "max-age=") {
|
||||
ageStr := strings.TrimPrefix(part, "max-age=")
|
||||
ageStr = strings.TrimSpace(ageStr)
|
||||
if strings.HasPrefix(part, "max-age=") {
|
||||
ageStr := strings.TrimSpace(strings.TrimPrefix(part, "max-age="))
|
||||
if age, err := strconv.Atoi(ageStr); err == nil {
|
||||
return time.Duration(age) * time.Second
|
||||
}
|
||||
@@ -215,6 +234,7 @@ func parseCacheControl(header string) time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
// getLimiter manages rate limiters per client IP
|
||||
func getLimiter(ip string) *rate.Limiter {
|
||||
rateMutex.Lock()
|
||||
defer rateMutex.Unlock()
|
||||
@@ -227,12 +247,13 @@ func getLimiter(ip string) *rate.Limiter {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// handler processes incoming HTTP/HTTPS requests
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
configMux.RLock()
|
||||
defer configMux.RUnlock()
|
||||
|
||||
// Rate limiting based on client IP
|
||||
clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")] // Strip port
|
||||
clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")]
|
||||
limiter := getLimiter(clientIP)
|
||||
if !limiter.Allow() {
|
||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||
@@ -240,6 +261,7 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve target and settings for the requested host
|
||||
target, exists := config.Routes[r.Host]
|
||||
skipVerify := config.TrustTarget[r.Host]
|
||||
noHTTPSRedirect := config.NoHTTPSRedirect[r.Host]
|
||||
@@ -253,15 +275,16 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
||||
noHTTPSRedirect = config.NoHTTPSRedirect["*"]
|
||||
}
|
||||
|
||||
// Handle HTTP->HTTPS redirect if applicable
|
||||
isHTTPS := target[:5] == "https"
|
||||
isHTTPReq := r.TLS == nil
|
||||
|
||||
if isHTTPReq && isHTTPS && !noHTTPSRedirect {
|
||||
redirectURL := "https://" + r.Host + r.RequestURI
|
||||
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
|
||||
// Check cache for static content
|
||||
cacheKey := r.URL.String()
|
||||
cacheMutex.RLock()
|
||||
if cached, ok := cache[cacheKey]; ok && time.Since(cached.cachedAt) < cached.cacheDuration {
|
||||
@@ -281,50 +304,65 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
cacheMutex.RUnlock()
|
||||
|
||||
proxy := getReverseProxy(target, skipVerify)
|
||||
proxy := getReverseProxy(target, skipVerify, r)
|
||||
if proxy == nil {
|
||||
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle WebSocket connections
|
||||
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Error parsing target URL %s: %v", target, err)
|
||||
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
||||
errorLogger.Printf("Error parsing target URL for WebSocket: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// WebSocket upgrade handling
|
||||
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "WebSocket upgrade not supported", http.StatusInternalServerError)
|
||||
errorLogger.Printf("WebSocket upgrade not supported by server")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
conn, _, err := hijacker.Hijack()
|
||||
clientConn, _, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
errorLogger.Printf("Failed to hijack connection for WebSocket: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
targetConn, err := transportPool.Dial("tcp", targetURL.Host)
|
||||
dialer := transportPool
|
||||
if !skipVerify {
|
||||
dialer = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
||||
}
|
||||
}
|
||||
targetConn, err := dialer.Dial("tcp", targetURL.Host)
|
||||
if err != nil {
|
||||
errorLogger.Printf("Failed to dial target for WebSocket: %v", err)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
// Forward request to target
|
||||
if err := r.Write(targetConn); err != nil {
|
||||
errorLogger.Printf("Failed to forward WebSocket request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Pipe connections
|
||||
go io.Copy(conn, targetConn)
|
||||
io.Copy(targetConn, conn)
|
||||
errChan := make(chan error, 2)
|
||||
go func() {
|
||||
_, err := io.Copy(targetConn, clientConn)
|
||||
errChan <- err
|
||||
}()
|
||||
go func() {
|
||||
_, err := io.Copy(clientConn, targetConn)
|
||||
errChan <- err
|
||||
}()
|
||||
trafficLogger.Printf("WebSocket connection established: %s -> %s", r.Host, target)
|
||||
<-errChan
|
||||
trafficLogger.Printf("WebSocket connection closed: %s -> %s", r.Host, target)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
+168
@@ -0,0 +1,168 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mock globals for testing
|
||||
var (
|
||||
mockConfig = Config{
|
||||
Routes: map[string]string{
|
||||
"test.local": "http://mock-target",
|
||||
"ws.local": "ws://mock-target",
|
||||
"cache.local": "http://mock-target",
|
||||
},
|
||||
TrustTarget: map[string]bool{
|
||||
"test.local": true,
|
||||
"ws.local": true,
|
||||
"cache.local": true,
|
||||
},
|
||||
NoHTTPSRedirect: map[string]bool{
|
||||
"test.local": true,
|
||||
"ws.local": true,
|
||||
"cache.local": true,
|
||||
},
|
||||
}
|
||||
mockConfigMux = sync.RWMutex{}
|
||||
mockLogger = log.New(io.Discard, "", 0) // Discard logs during testing
|
||||
)
|
||||
|
||||
// TestHandlerRoute tests basic routing to a target
|
||||
func TestHandlerRoute(t *testing.T) {
|
||||
// Mock target server
|
||||
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Hello from target"))
|
||||
}))
|
||||
defer targetServer.Close()
|
||||
|
||||
// Setup mock globals
|
||||
configMux.Lock()
|
||||
config = mockConfig
|
||||
config.Routes["test.local"] = targetServer.URL
|
||||
trafficLogger = mockLogger
|
||||
configMux.Unlock()
|
||||
|
||||
// Create request
|
||||
req, _ := http.NewRequest("GET", "http://test.local", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Run handler
|
||||
handler(rr, req)
|
||||
|
||||
// Check response
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||||
}
|
||||
if body := rr.Body.String(); body != "Hello from target" {
|
||||
t.Errorf("handler returned unexpected body: got %v want %v", body, "Hello from target")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandlerWebSocket tests WebSocket upgrade handling
|
||||
func TestHandlerWebSocket(t *testing.T) {
|
||||
// Mock WebSocket target (simplified)
|
||||
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Upgrade") == "websocket" {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
t.Fatal("Server doesn’t support hijacking")
|
||||
}
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Write([]byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"))
|
||||
conn.Close()
|
||||
}
|
||||
}))
|
||||
defer targetServer.Close()
|
||||
|
||||
// Setup mock globals
|
||||
configMux.Lock()
|
||||
config = mockConfig
|
||||
config.Routes["ws.local"] = targetServer.URL
|
||||
trafficLogger = mockLogger
|
||||
configMux.Unlock()
|
||||
|
||||
// Create WebSocket request
|
||||
req, _ := http.NewRequest("GET", "http://ws.local", nil)
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
|
||||
// Use a custom recorder for hijacking
|
||||
rr := &hijackRecorder{ResponseRecorder: httptest.NewRecorder()}
|
||||
handler(rr, req)
|
||||
|
||||
// Check if hijacking occurred
|
||||
if !rr.hijacked {
|
||||
t.Errorf("Expected WebSocket hijacking, but it didn’t occur")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandlerCache tests caching functionality
|
||||
func TestHandlerCache(t *testing.T) {
|
||||
// Mock target server
|
||||
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.Write([]byte("Cached content"))
|
||||
}))
|
||||
defer targetServer.Close()
|
||||
|
||||
// Setup mock globals
|
||||
configMux.Lock()
|
||||
config = mockConfig
|
||||
config.Routes["cache.local"] = targetServer.URL
|
||||
trafficLogger = mockLogger
|
||||
cache = make(map[string]cachedResponse) // Reset cache for test isolation
|
||||
configMux.Unlock()
|
||||
|
||||
// First request to cache
|
||||
req, _ := http.NewRequest("GET", "http://cache.local", nil)
|
||||
rr1 := httptest.NewRecorder()
|
||||
handler(rr1, req)
|
||||
|
||||
// Second request to hit cache
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler(rr2, req)
|
||||
|
||||
// Check if second response is from cache
|
||||
if rr2.Code != http.StatusOK {
|
||||
t.Errorf("Expected status OK, got %v", rr2.Code)
|
||||
}
|
||||
if body := rr2.Body.String(); body != "Cached content" {
|
||||
t.Errorf("Expected cached response, got %v", body)
|
||||
}
|
||||
}
|
||||
|
||||
// hijackRecorder mocks ResponseRecorder with Hijack support for WebSocket testing
|
||||
type hijackRecorder struct {
|
||||
*httptest.ResponseRecorder
|
||||
hijacked bool
|
||||
}
|
||||
|
||||
func (hr *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hr.hijacked = true
|
||||
return &mockConn{Reader: bytes.NewReader([]byte("mock response")), Writer: hr.Body}, nil, nil
|
||||
}
|
||||
|
||||
// mockConn is a simple net.Conn mock for testing
|
||||
type mockConn struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (mc *mockConn) Close() error { return nil }
|
||||
func (mc *mockConn) LocalAddr() net.Addr { return nil }
|
||||
func (mc *mockConn) RemoteAddr() net.Addr { return nil }
|
||||
func (mc *mockConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (mc *mockConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
@@ -2,9 +2,10 @@ package main
|
||||
|
||||
import "path/filepath"
|
||||
|
||||
// updatePaths sets certificate and config file paths based on the current config
|
||||
func updatePaths() {
|
||||
certDir = filepath.Join(baseDir, config.CertDir)
|
||||
certPath = filepath.Join(certDir, config.CertFile)
|
||||
keyPath = filepath.Join(certDir, config.KeyFile)
|
||||
configPath = filepath.Join(baseDir, "config.yaml")
|
||||
certDir = filepath.Join(baseDir, config.CertDir) // Certificate directory path
|
||||
certPath = filepath.Join(certDir, config.CertFile) // Full certificate file path
|
||||
keyPath = filepath.Join(certDir, config.KeyFile) // Full private key file path
|
||||
configPath = filepath.Join(baseDir, "config.yaml") // Config file path
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user