caching and logging
This commit is contained in:
@@ -0,0 +1,125 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"math/big"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func generateSelfSignedCert() error {
|
||||||
|
if err := os.MkdirAll(certDir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create certificate directory %s: %v", certDir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate private key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
notBefore := time.Now()
|
||||||
|
notAfter := notBefore.Add(365 * 24 * time.Hour)
|
||||||
|
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate serial number: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
template := x509.Certificate{
|
||||||
|
SerialNumber: serialNumber,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"Proxy Self-Signed"},
|
||||||
|
CommonName: "localhost",
|
||||||
|
},
|
||||||
|
NotBefore: notBefore,
|
||||||
|
NotAfter: notAfter,
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
DNSNames: []string{"localhost", "*.example.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certOut, err := os.Create(certPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open %s for writing: %v", certPath, err)
|
||||||
|
}
|
||||||
|
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||||
|
certOut.Close()
|
||||||
|
return fmt.Errorf("failed to encode certificate: %v", err)
|
||||||
|
}
|
||||||
|
certOut.Close()
|
||||||
|
|
||||||
|
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
|
||||||
|
}
|
||||||
|
if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
|
||||||
|
keyOut.Close()
|
||||||
|
return fmt.Errorf("failed to encode private key: %v", err)
|
||||||
|
}
|
||||||
|
keyOut.Close()
|
||||||
|
|
||||||
|
log.Printf("Generated self-signed certificate in %s", certDir)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadCertificate() error {
|
||||||
|
certFile, err := os.ReadFile(certPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read certificate %s: %v", certPath, err)
|
||||||
|
}
|
||||||
|
keyFile, err := os.ReadFile(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read key %s: %v", keyPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newCert, err := tls.X509KeyPair(certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configMux.Lock()
|
||||||
|
cert = &newCert
|
||||||
|
configMux.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func monitorCertificates() {
|
||||||
|
var lastModTime time.Time
|
||||||
|
for {
|
||||||
|
certInfo, err := os.Stat(certPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error checking certificate: %v", err)
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
keyInfo, err := os.Stat(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error checking key: %v", err)
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if certInfo.ModTime() != lastModTime || keyInfo.ModTime() != lastModTime {
|
||||||
|
if err := loadCertificate(); err != nil {
|
||||||
|
log.Printf("Error reloading certificate: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Println("Certificate reloaded successfully")
|
||||||
|
lastModTime = certInfo.ModTime()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
+154
@@ -0,0 +1,154 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
ListenHTTP string `yaml:"listen_http"`
|
||||||
|
ListenHTTPS string `yaml:"listen_https"`
|
||||||
|
CertDir string `yaml:"cert_dir"`
|
||||||
|
CertFile string `yaml:"cert_file"`
|
||||||
|
KeyFile string `yaml:"key_file"`
|
||||||
|
Routes map[string]string `yaml:"routes"`
|
||||||
|
TrustTarget map[string]bool `yaml:"trust_target"`
|
||||||
|
NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"` // New field
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadConfig() (Config, error) {
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("failed to read config %s: %v", configPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg Config
|
||||||
|
err = yaml.Unmarshal(data, &cfg)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("failed to unmarshal config: %v", err)
|
||||||
|
}
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateDefaultConfig() Config {
|
||||||
|
return Config{
|
||||||
|
ListenHTTP: ":80",
|
||||||
|
ListenHTTPS: ":443",
|
||||||
|
CertDir: "certificates", // default current directory creates certificates folder
|
||||||
|
CertFile: "certificate.pem",
|
||||||
|
KeyFile: "key.pem",
|
||||||
|
Routes: map[string]string{
|
||||||
|
"*": "http://127.0.0.1:80",
|
||||||
|
"main.example.com": "http://127.0.0.1:80",
|
||||||
|
},
|
||||||
|
TrustTarget: map[string]bool{
|
||||||
|
"*": true, // default trust all certificates
|
||||||
|
"main.example.com": false, // use only trusted certificate
|
||||||
|
},
|
||||||
|
NoHTTPSRedirect: map[string]bool{
|
||||||
|
"*": false, // Default: redirect to HTTPS
|
||||||
|
"main.example.com": true, // set to not redirect HTTP to HTTPS
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveConfig(cfg Config) error {
|
||||||
|
if err := os.MkdirAll(baseDir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create base directory %s: %v", baseDir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := yaml.Marshal(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = os.WriteFile(configPath, data, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write config to %s: %v", configPath, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func monitorConfig() {
|
||||||
|
var lastModTime time.Time
|
||||||
|
for {
|
||||||
|
configInfo, err := os.Stat(configPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error checking config file: %v", err)
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if configInfo.ModTime() != lastModTime {
|
||||||
|
newConfig, err := loadConfig()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error reloading config: %v", err)
|
||||||
|
} else {
|
||||||
|
configMux.Lock()
|
||||||
|
if newConfig.ListenHTTP != config.ListenHTTP {
|
||||||
|
config.ListenHTTP = newConfig.ListenHTTP
|
||||||
|
log.Printf("Updated listen_http to %s", config.ListenHTTP)
|
||||||
|
}
|
||||||
|
if newConfig.ListenHTTPS != config.ListenHTTPS {
|
||||||
|
config.ListenHTTPS = newConfig.ListenHTTPS
|
||||||
|
log.Printf("Updated listen_https to %s", config.ListenHTTPS)
|
||||||
|
}
|
||||||
|
if newConfig.CertDir != config.CertDir || newConfig.CertFile != config.CertFile || newConfig.KeyFile != config.KeyFile {
|
||||||
|
config.CertDir = newConfig.CertDir
|
||||||
|
config.CertFile = newConfig.CertFile
|
||||||
|
config.KeyFile = newConfig.KeyFile
|
||||||
|
updatePaths()
|
||||||
|
if err := loadCertificate(); err != nil {
|
||||||
|
log.Printf("Error reloading certificate after path change: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Println("Updated certificate paths and reloaded certificate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k, v := range newConfig.Routes {
|
||||||
|
if oldV, exists := config.Routes[k]; !exists || oldV != v {
|
||||||
|
config.Routes[k] = v
|
||||||
|
log.Printf("Updated route %s to %s", k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k := range config.Routes {
|
||||||
|
if _, exists := newConfig.Routes[k]; !exists {
|
||||||
|
delete(config.Routes, k)
|
||||||
|
log.Printf("Removed route %s", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k, v := range newConfig.TrustTarget {
|
||||||
|
if oldV, exists := config.TrustTarget[k]; !exists || oldV != v {
|
||||||
|
config.TrustTarget[k] = v
|
||||||
|
log.Printf("Updated trust_target %s to %v", k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k := range config.TrustTarget {
|
||||||
|
if _, exists := newConfig.TrustTarget[k]; !exists {
|
||||||
|
delete(config.TrustTarget, k)
|
||||||
|
log.Printf("Removed trust_target %s", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k, v := range newConfig.NoHTTPSRedirect {
|
||||||
|
if oldV, exists := config.NoHTTPSRedirect[k]; !exists || oldV != v {
|
||||||
|
config.NoHTTPSRedirect[k] = v
|
||||||
|
log.Printf("Updated no_https_redirect %s to %v", k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k := range config.NoHTTPSRedirect {
|
||||||
|
if _, exists := newConfig.NoHTTPSRedirect[k]; !exists {
|
||||||
|
delete(config.NoHTTPSRedirect, k)
|
||||||
|
log.Printf("Removed no_https_redirect %s", k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
configMux.Unlock()
|
||||||
|
log.Println("Config reloaded successfully")
|
||||||
|
lastModTime = configInfo.ModTime()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
+105
@@ -0,0 +1,105 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// go run main.go config.go certificate.go proxy.go utils.go
|
||||||
|
// go build -o proxy
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Shared global variables (declared only here)
|
||||||
|
var (
|
||||||
|
config Config
|
||||||
|
configMux sync.RWMutex
|
||||||
|
cert *tls.Certificate
|
||||||
|
baseDir string
|
||||||
|
certDir string
|
||||||
|
certPath string
|
||||||
|
keyPath string
|
||||||
|
configPath string
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if runtime.GOOS == "windows" && len(os.Args) > 0 && filepath.Ext(os.Args[0]) == ".go" {
|
||||||
|
var err error
|
||||||
|
baseDir, err = os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to get working directory: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
exePath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to get executable path: %v", err)
|
||||||
|
}
|
||||||
|
baseDir = filepath.Dir(exePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
configPath = filepath.Join(baseDir, "config.yaml")
|
||||||
|
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||||
|
config = generateDefaultConfig()
|
||||||
|
if err := saveConfig(config); err != nil {
|
||||||
|
log.Fatalf("Failed to save default config: %v", err)
|
||||||
|
}
|
||||||
|
log.Println("Generated default config file")
|
||||||
|
} else {
|
||||||
|
cfg, err := loadConfig()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
config = cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
updatePaths()
|
||||||
|
|
||||||
|
_, certErr := os.Stat(certPath)
|
||||||
|
_, keyErr := os.Stat(keyPath)
|
||||||
|
if os.IsNotExist(certErr) || os.IsNotExist(keyErr) {
|
||||||
|
if err := generateSelfSignedCert(); err != nil {
|
||||||
|
log.Fatalf("Failed to generate self-signed certificate: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := loadCertificate(); err != nil {
|
||||||
|
log.Fatalf("Failed to load certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go monitorCertificates()
|
||||||
|
go monitorConfig()
|
||||||
|
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: config.ListenHTTP,
|
||||||
|
Handler: http.HandlerFunc(handler),
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
configMux.RLock()
|
||||||
|
defer configMux.RUnlock()
|
||||||
|
return cert, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
httpsServer := &http.Server{
|
||||||
|
Addr: config.ListenHTTPS,
|
||||||
|
Handler: http.HandlerFunc(handler),
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
||||||
|
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatalf("HTTP server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
||||||
|
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatalf("HTTPS server error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||||
|
targetURL, err := url.Parse(target)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error parsing target URL %s: %v", target, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
director := func(req *http.Request) {
|
||||||
|
req.URL.Scheme = targetURL.Scheme
|
||||||
|
req.URL.Host = targetURL.Host
|
||||||
|
req.Host = targetURL.Host // Set Host header to match target
|
||||||
|
|
||||||
|
// Preserve original headers
|
||||||
|
for k, v := range req.Header {
|
||||||
|
if k != "Host" { // Host is set above
|
||||||
|
req.Header[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the full path is preserved
|
||||||
|
if targetURL.Path != "" {
|
||||||
|
req.URL.Path = strings.TrimPrefix(req.URL.Path, "/") // Avoid double slashes
|
||||||
|
req.URL.Path = singleJoin(targetURL.Path, req.URL.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: skipVerify},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &httputil.ReverseProxy{
|
||||||
|
Director: director,
|
||||||
|
Transport: transport,
|
||||||
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
log.Printf("Proxy error for %s: %v", r.Host, err)
|
||||||
|
http.Error(w, "Proxy error", http.StatusBadGateway)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// singleJoin ensures a single slash between path segments
|
||||||
|
func singleJoin(prefix, suffix string) string {
|
||||||
|
prefix = strings.TrimSuffix(prefix, "/")
|
||||||
|
suffix = strings.TrimPrefix(suffix, "/")
|
||||||
|
return prefix + "/" + suffix
|
||||||
|
}
|
||||||
|
|
||||||
|
func handler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
configMux.RLock()
|
||||||
|
defer configMux.RUnlock()
|
||||||
|
|
||||||
|
target, exists := config.Routes[r.Host]
|
||||||
|
skipVerify := config.TrustTarget[r.Host]
|
||||||
|
noHTTPSRedirect := config.NoHTTPSRedirect[r.Host]
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
if target, exists = config.Routes["*"]; !exists {
|
||||||
|
http.Error(w, "Host not configured", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
skipVerify = config.TrustTarget["*"]
|
||||||
|
noHTTPSRedirect = config.NoHTTPSRedirect["*"]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if request is HTTP and target is HTTPS
|
||||||
|
isHTTPS := target[:5] == "https"
|
||||||
|
isHTTPReq := r.TLS == nil // r.TLS is nil for HTTP, non-nil for HTTPS
|
||||||
|
|
||||||
|
if isHTTPReq && isHTTPS && !noHTTPSRedirect {
|
||||||
|
// Redirect to HTTPS version of the host
|
||||||
|
redirectURL := "https://" + r.Host + r.RequestURI
|
||||||
|
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := getReverseProxy(target, skipVerify)
|
||||||
|
if proxy == nil {
|
||||||
|
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
proxy.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "path/filepath"
|
||||||
|
|
||||||
|
func updatePaths() {
|
||||||
|
certDir = filepath.Join(baseDir, config.CertDir)
|
||||||
|
certPath = filepath.Join(certDir, config.CertFile)
|
||||||
|
keyPath = filepath.Join(certDir, config.KeyFile)
|
||||||
|
configPath = filepath.Join(baseDir, "config.yaml")
|
||||||
|
}
|
||||||
+6
-7
@@ -8,7 +8,6 @@ import (
|
|||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"math/big"
|
"math/big"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
@@ -60,7 +59,7 @@ func generateSelfSignedCert() error {
|
|||||||
}
|
}
|
||||||
certOut.Close()
|
certOut.Close()
|
||||||
|
|
||||||
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_WRONLY, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
|
return fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
|
||||||
}
|
}
|
||||||
@@ -70,7 +69,7 @@ func generateSelfSignedCert() error {
|
|||||||
}
|
}
|
||||||
keyOut.Close()
|
keyOut.Close()
|
||||||
|
|
||||||
log.Printf("Generated self-signed certificate in %s", certDir)
|
refreshLogger.Printf("Generated self-signed certificate in %s", certDir)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,23 +99,23 @@ func monitorCertificates() {
|
|||||||
for {
|
for {
|
||||||
certInfo, err := os.Stat(certPath)
|
certInfo, err := os.Stat(certPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error checking certificate: %v", err)
|
errorLogger.Printf("Error checking certificate: %v", err)
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
keyInfo, err := os.Stat(keyPath)
|
keyInfo, err := os.Stat(keyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error checking key: %v", err)
|
errorLogger.Printf("Error checking key: %v", err)
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if certInfo.ModTime() != lastModTime || keyInfo.ModTime() != lastModTime {
|
if certInfo.ModTime() != lastModTime || keyInfo.ModTime() != lastModTime {
|
||||||
if err := loadCertificate(); err != nil {
|
if err := loadCertificate(); err != nil {
|
||||||
log.Printf("Error reloading certificate: %v", err)
|
errorLogger.Printf("Error reloading certificate: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Println("Certificate reloaded successfully")
|
refreshLogger.Println("Certificate reloaded successfully")
|
||||||
lastModTime = certInfo.ModTime()
|
lastModTime = certInfo.ModTime()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,7 +16,7 @@ type Config struct {
|
|||||||
KeyFile string `yaml:"key_file"`
|
KeyFile string `yaml:"key_file"`
|
||||||
Routes map[string]string `yaml:"routes"`
|
Routes map[string]string `yaml:"routes"`
|
||||||
TrustTarget map[string]bool `yaml:"trust_target"`
|
TrustTarget map[string]bool `yaml:"trust_target"`
|
||||||
NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"` // New field
|
NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadConfig() (Config, error) {
|
func loadConfig() (Config, error) {
|
||||||
@@ -38,20 +37,20 @@ func generateDefaultConfig() Config {
|
|||||||
return Config{
|
return Config{
|
||||||
ListenHTTP: ":80",
|
ListenHTTP: ":80",
|
||||||
ListenHTTPS: ":443",
|
ListenHTTPS: ":443",
|
||||||
CertDir: "./certificates", // default current directory creates certificates folder
|
CertDir: "certificates",
|
||||||
CertFile: "certificate.pem",
|
CertFile: "certificate.pem",
|
||||||
KeyFile: "key.pem",
|
KeyFile: "key.pem",
|
||||||
Routes: map[string]string{
|
Routes: map[string]string{
|
||||||
"*": "http://127.0.0.1:80",
|
"*": "https://127.0.0.1:3000",
|
||||||
"main.example.com": "http://127.0.0.1:80",
|
"main.example.com": "https://10.100.111.254:4444",
|
||||||
},
|
},
|
||||||
TrustTarget: map[string]bool{
|
TrustTarget: map[string]bool{
|
||||||
"*": true, // default trust all certificates
|
"*": true,
|
||||||
"main.example.com": false, // use only trusted certificate
|
"main.example.com": true,
|
||||||
},
|
},
|
||||||
NoHTTPSRedirect: map[string]bool{
|
NoHTTPSRedirect: map[string]bool{
|
||||||
"*": false, // Default: redirect to HTTPS
|
"*": false,
|
||||||
"main.example.com": true, // set to not redirect HTTP to HTTPS
|
"main.example.com": false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -78,7 +77,7 @@ func monitorConfig() {
|
|||||||
for {
|
for {
|
||||||
configInfo, err := os.Stat(configPath)
|
configInfo, err := os.Stat(configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error checking config file: %v", err)
|
errorLogger.Printf("Error checking config file: %v", err)
|
||||||
time.Sleep(5 * time.Second)
|
time.Sleep(5 * time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -86,16 +85,16 @@ func monitorConfig() {
|
|||||||
if configInfo.ModTime() != lastModTime {
|
if configInfo.ModTime() != lastModTime {
|
||||||
newConfig, err := loadConfig()
|
newConfig, err := loadConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error reloading config: %v", err)
|
errorLogger.Printf("Error reloading config: %v", err)
|
||||||
} else {
|
} else {
|
||||||
configMux.Lock()
|
configMux.Lock()
|
||||||
if newConfig.ListenHTTP != config.ListenHTTP {
|
if newConfig.ListenHTTP != config.ListenHTTP {
|
||||||
config.ListenHTTP = newConfig.ListenHTTP
|
config.ListenHTTP = newConfig.ListenHTTP
|
||||||
log.Printf("Updated listen_http to %s", config.ListenHTTP)
|
refreshLogger.Printf("Updated listen_http to %s", config.ListenHTTP)
|
||||||
}
|
}
|
||||||
if newConfig.ListenHTTPS != config.ListenHTTPS {
|
if newConfig.ListenHTTPS != config.ListenHTTPS {
|
||||||
config.ListenHTTPS = newConfig.ListenHTTPS
|
config.ListenHTTPS = newConfig.ListenHTTPS
|
||||||
log.Printf("Updated listen_https to %s", config.ListenHTTPS)
|
refreshLogger.Printf("Updated listen_https to %s", config.ListenHTTPS)
|
||||||
}
|
}
|
||||||
if newConfig.CertDir != config.CertDir || newConfig.CertFile != config.CertFile || newConfig.KeyFile != config.KeyFile {
|
if newConfig.CertDir != config.CertDir || newConfig.CertFile != config.CertFile || newConfig.KeyFile != config.KeyFile {
|
||||||
config.CertDir = newConfig.CertDir
|
config.CertDir = newConfig.CertDir
|
||||||
@@ -103,49 +102,49 @@ func monitorConfig() {
|
|||||||
config.KeyFile = newConfig.KeyFile
|
config.KeyFile = newConfig.KeyFile
|
||||||
updatePaths()
|
updatePaths()
|
||||||
if err := loadCertificate(); err != nil {
|
if err := loadCertificate(); err != nil {
|
||||||
log.Printf("Error reloading certificate after path change: %v", err)
|
errorLogger.Printf("Error reloading certificate after path change: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Println("Updated certificate paths and reloaded certificate")
|
refreshLogger.Println("Updated certificate paths and reloaded certificate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for k, v := range newConfig.Routes {
|
for k, v := range newConfig.Routes {
|
||||||
if oldV, exists := config.Routes[k]; !exists || oldV != v {
|
if oldV, exists := config.Routes[k]; !exists || oldV != v {
|
||||||
config.Routes[k] = v
|
config.Routes[k] = v
|
||||||
log.Printf("Updated route %s to %s", k, v)
|
refreshLogger.Printf("Updated route %s to %s", k, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for k := range config.Routes {
|
for k := range config.Routes {
|
||||||
if _, exists := newConfig.Routes[k]; !exists {
|
if _, exists := newConfig.Routes[k]; !exists {
|
||||||
delete(config.Routes, k)
|
delete(config.Routes, k)
|
||||||
log.Printf("Removed route %s", k)
|
refreshLogger.Printf("Removed route %s", k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for k, v := range newConfig.TrustTarget {
|
for k, v := range newConfig.TrustTarget {
|
||||||
if oldV, exists := config.TrustTarget[k]; !exists || oldV != v {
|
if oldV, exists := config.TrustTarget[k]; !exists || oldV != v {
|
||||||
config.TrustTarget[k] = v
|
config.TrustTarget[k] = v
|
||||||
log.Printf("Updated trust_target %s to %v", k, v)
|
refreshLogger.Printf("Updated trust_target %s to %v", k, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for k := range config.TrustTarget {
|
for k := range config.TrustTarget {
|
||||||
if _, exists := newConfig.TrustTarget[k]; !exists {
|
if _, exists := newConfig.TrustTarget[k]; !exists {
|
||||||
delete(config.TrustTarget, k)
|
delete(config.TrustTarget, k)
|
||||||
log.Printf("Removed trust_target %s", k)
|
refreshLogger.Printf("Removed trust_target %s", k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for k, v := range newConfig.NoHTTPSRedirect {
|
for k, v := range newConfig.NoHTTPSRedirect {
|
||||||
if oldV, exists := config.NoHTTPSRedirect[k]; !exists || oldV != v {
|
if oldV, exists := config.NoHTTPSRedirect[k]; !exists || oldV != v {
|
||||||
config.NoHTTPSRedirect[k] = v
|
config.NoHTTPSRedirect[k] = v
|
||||||
log.Printf("Updated no_https_redirect %s to %v", k, v)
|
refreshLogger.Printf("Updated no_https_redirect %s to %v", k, v)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for k := range config.NoHTTPSRedirect {
|
for k := range config.NoHTTPSRedirect {
|
||||||
if _, exists := newConfig.NoHTTPSRedirect[k]; !exists {
|
if _, exists := newConfig.NoHTTPSRedirect[k]; !exists {
|
||||||
delete(config.NoHTTPSRedirect, k)
|
delete(config.NoHTTPSRedirect, k)
|
||||||
log.Printf("Removed no_https_redirect %s", k)
|
refreshLogger.Printf("Removed no_https_redirect %s", k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
configMux.Unlock()
|
configMux.Unlock()
|
||||||
log.Println("Config reloaded successfully")
|
refreshLogger.Println("Config reloaded successfully")
|
||||||
lastModTime = configInfo.ModTime()
|
lastModTime = configInfo.ModTime()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,3 +3,9 @@ module main
|
|||||||
go 1.24.0
|
go 1.24.0
|
||||||
|
|
||||||
require gopkg.in/yaml.v2 v2.4.0
|
require gopkg.in/yaml.v2 v2.4.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
||||||
|
golang.org/x/sys v0.13.0 // indirect
|
||||||
|
golang.org/x/time v0.10.0 // indirect
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,3 +1,9 @@
|
|||||||
|
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
||||||
|
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
|
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||||
|
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
||||||
|
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||||
|
|||||||
@@ -1,18 +1,17 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
// go run main.go config.go certificate.go proxy.go utils.go
|
|
||||||
// go build -o proxy
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Shared global variables (declared only here)
|
|
||||||
var (
|
var (
|
||||||
config Config
|
config Config
|
||||||
configMux sync.RWMutex
|
configMux sync.RWMutex
|
||||||
@@ -22,21 +21,106 @@ var (
|
|||||||
certPath string
|
certPath string
|
||||||
keyPath string
|
keyPath string
|
||||||
configPath string
|
configPath string
|
||||||
|
|
||||||
|
// Loggers
|
||||||
|
errorLogger *log.Logger
|
||||||
|
refreshLogger *log.Logger
|
||||||
|
trafficLogger *log.Logger
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if runtime.GOOS == "windows" && len(os.Args) > 0 && filepath.Ext(os.Args[0]) == ".go" {
|
// Get the absolute path of the running executable (or main.go in `go run`)
|
||||||
var err error
|
exePath, err := os.Executable()
|
||||||
baseDir, err = os.Getwd()
|
if err != nil {
|
||||||
if err != nil {
|
log.Fatalf("Failed to get executable path: %v", err)
|
||||||
log.Fatalf("Failed to get working directory: %v", err)
|
}
|
||||||
|
baseDir = filepath.Dir(exePath)
|
||||||
|
/*
|
||||||
|
if runtime.GOOS == "windows" && len(os.Args) > 0 && filepath.Ext(os.Args[0]) == ".go" {
|
||||||
|
var err error
|
||||||
|
baseDir, err = os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to get working directory: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
exePath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to get executable path: %v", err)
|
||||||
|
}
|
||||||
|
baseDir = filepath.Dir(exePath)
|
||||||
}
|
}
|
||||||
} else {
|
*/
|
||||||
exePath, err := os.Executable()
|
// Setup logging
|
||||||
|
setupLogging()
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupLogging() {
|
||||||
|
logsDir := filepath.Join(baseDir, "logs")
|
||||||
|
if err := os.MkdirAll(logsDir, 0755); err != nil {
|
||||||
|
log.Fatalf("Failed to create logs directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error log
|
||||||
|
errorFile, err := os.OpenFile(filepath.Join(logsDir, "errors.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to open errors.log: %v", err)
|
||||||
|
}
|
||||||
|
errorLogger = log.New(io.MultiWriter(os.Stdout, errorFile), "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||||
|
|
||||||
|
// Refresh log
|
||||||
|
refreshFile, err := os.OpenFile(filepath.Join(logsDir, "refresh.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to open refresh.log: %v", err)
|
||||||
|
}
|
||||||
|
refreshLogger = log.New(io.MultiWriter(os.Stdout, refreshFile), "REFRESH: ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||||
|
|
||||||
|
// Traffic log (daily rotation)
|
||||||
|
go manageTrafficLogs(logsDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func manageTrafficLogs(logsDir string) {
|
||||||
|
for {
|
||||||
|
dateStr := time.Now().Format("2006-01-02")
|
||||||
|
trafficFile, err := os.OpenFile(filepath.Join(logsDir, "traffic-"+dateStr+".log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to get executable path: %v", err)
|
errorLogger.Printf("Failed to open traffic log: %v", err)
|
||||||
|
time.Sleep(1 * time.Minute)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
trafficLogger = log.New(io.MultiWriter(os.Stdout, trafficFile), "TRAFFIC: ", log.Ldate|log.Ltime)
|
||||||
|
|
||||||
|
// Cleanup old logs
|
||||||
|
cleanupOldLogs(logsDir)
|
||||||
|
|
||||||
|
// Wait until next day
|
||||||
|
nextDay := time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour)
|
||||||
|
time.Sleep(time.Until(nextDay))
|
||||||
|
trafficFile.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupOldLogs(logsDir string) {
|
||||||
|
files, err := os.ReadDir(logsDir)
|
||||||
|
if err != nil {
|
||||||
|
errorLogger.Printf("Failed to read logs directory: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cutoff := time.Now().AddDate(0, 0, -7) // 7 days ago
|
||||||
|
for _, file := range files {
|
||||||
|
if strings.HasPrefix(file.Name(), "traffic-") && file.Name() != "traffic-"+time.Now().Format("2006-01-02")+".log" {
|
||||||
|
info, err := file.Info()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if info.ModTime().Before(cutoff) {
|
||||||
|
if err := os.Remove(filepath.Join(logsDir, file.Name())); err != nil {
|
||||||
|
errorLogger.Printf("Failed to remove old log %s: %v", file.Name(), err)
|
||||||
|
} else {
|
||||||
|
refreshLogger.Printf("Removed old traffic log: %s", file.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
baseDir = filepath.Dir(exePath)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,13 +129,13 @@ func main() {
|
|||||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||||
config = generateDefaultConfig()
|
config = generateDefaultConfig()
|
||||||
if err := saveConfig(config); err != nil {
|
if err := saveConfig(config); err != nil {
|
||||||
log.Fatalf("Failed to save default config: %v", err)
|
errorLogger.Fatalf("Failed to save default config: %v", err)
|
||||||
}
|
}
|
||||||
log.Println("Generated default config file")
|
refreshLogger.Println("Generated default config file")
|
||||||
} else {
|
} else {
|
||||||
cfg, err := loadConfig()
|
cfg, err := loadConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to load config: %v", err)
|
errorLogger.Fatalf("Failed to load config: %v", err)
|
||||||
}
|
}
|
||||||
config = cfg
|
config = cfg
|
||||||
}
|
}
|
||||||
@@ -62,20 +146,23 @@ func main() {
|
|||||||
_, keyErr := os.Stat(keyPath)
|
_, keyErr := os.Stat(keyPath)
|
||||||
if os.IsNotExist(certErr) || os.IsNotExist(keyErr) {
|
if os.IsNotExist(certErr) || os.IsNotExist(keyErr) {
|
||||||
if err := generateSelfSignedCert(); err != nil {
|
if err := generateSelfSignedCert(); err != nil {
|
||||||
log.Fatalf("Failed to generate self-signed certificate: %v", err)
|
errorLogger.Fatalf("Failed to generate self-signed certificate: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := loadCertificate(); err != nil {
|
if err := loadCertificate(); err != nil {
|
||||||
log.Fatalf("Failed to load certificate: %v", err)
|
errorLogger.Fatalf("Failed to load certificate: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go monitorCertificates()
|
go monitorCertificates()
|
||||||
go monitorConfig()
|
go monitorConfig()
|
||||||
|
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: config.ListenHTTP,
|
Addr: config.ListenHTTP,
|
||||||
Handler: http.HandlerFunc(handler),
|
Handler: http.HandlerFunc(handler),
|
||||||
|
MaxHeaderBytes: 1 << 20,
|
||||||
|
ReadTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
@@ -86,20 +173,25 @@ func main() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
httpsServer := &http.Server{
|
httpsServer := &http.Server{
|
||||||
Addr: config.ListenHTTPS,
|
Addr: config.ListenHTTPS,
|
||||||
Handler: http.HandlerFunc(handler),
|
Handler: http.HandlerFunc(handler),
|
||||||
TLSConfig: tlsConfig,
|
TLSConfig: tlsConfig,
|
||||||
|
MaxHeaderBytes: 1 << 20,
|
||||||
|
ReadTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
log.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
log.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
||||||
|
trafficLogger.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
||||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
log.Fatalf("HTTP server error: %v", err)
|
errorLogger.Fatalf("HTTP server error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
log.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
||||||
|
trafficLogger.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
||||||
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
||||||
log.Fatalf("HTTPS server error: %v", err)
|
errorLogger.Fatalf("HTTPS server error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,65 +1,245 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"crypto/md5"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
transportPool = &http.Transport{
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext,
|
||||||
|
ResponseHeaderTimeout: 60 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
cache = make(map[string]cachedResponse)
|
||||||
|
defaultCacheTTL = 5 * time.Minute
|
||||||
|
cacheMutex sync.RWMutex
|
||||||
|
|
||||||
|
// Rate limiter per client IP
|
||||||
|
rateLimiters = make(map[string]*rate.Limiter)
|
||||||
|
rateMutex sync.RWMutex
|
||||||
|
rateLimit = rate.Limit(10) // 10 requests per second per client
|
||||||
|
rateBurst = 20 // Allow burst of 20 requests
|
||||||
|
)
|
||||||
|
|
||||||
|
type cachedResponse struct {
|
||||||
|
body []byte
|
||||||
|
headers http.Header
|
||||||
|
statusCode int
|
||||||
|
cachedAt time.Time
|
||||||
|
cacheDuration time.Duration
|
||||||
|
etag string
|
||||||
|
}
|
||||||
|
|
||||||
func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
||||||
targetURL, err := url.Parse(target)
|
targetURL, err := url.Parse(target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error parsing target URL %s: %v", target, err)
|
log.Printf("Error parsing target URL %s: %v", target, err)
|
||||||
|
errorLogger.Printf("Error parsing target URL %s: %v", target, err)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
director := func(req *http.Request) {
|
director := func(req *http.Request) {
|
||||||
req.URL.Scheme = targetURL.Scheme
|
req.URL.Scheme = targetURL.Scheme
|
||||||
req.URL.Host = targetURL.Host
|
req.URL.Host = targetURL.Host
|
||||||
req.Host = targetURL.Host // Set Host header to match target
|
// Preserve client's Host header for session continuity
|
||||||
|
if req.Header.Get("Host") != "" {
|
||||||
// Preserve original headers
|
req.Host = req.Header.Get("Host")
|
||||||
for k, v := range req.Header {
|
} else {
|
||||||
if k != "Host" { // Host is set above
|
req.Host = targetURL.Host
|
||||||
req.Header[k] = v
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the full path is preserved
|
// Preserve all request headers, including cookies
|
||||||
|
for k, v := range req.Header {
|
||||||
|
req.Header[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve query parameters and path
|
||||||
|
req.URL.RawQuery = req.URL.RawQuery
|
||||||
if targetURL.Path != "" {
|
if targetURL.Path != "" {
|
||||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, "/") // Avoid double slashes
|
req.URL.Path = strings.TrimPrefix(req.URL.Path, "/")
|
||||||
req.URL.Path = singleJoin(targetURL.Path, req.URL.Path)
|
req.URL.Path = singleJoin(targetURL.Path, req.URL.Path)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WebSocket support: pass upgrade headers
|
||||||
|
if strings.ToLower(req.Header.Get("Upgrade")) == "websocket" {
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
}
|
||||||
|
|
||||||
|
trafficLogger.Printf("Request: %s %s -> %s [Host: %s]", req.Method, req.URL.String(), target, req.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
transport := &http.Transport{
|
transport := transportPool
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: skipVerify},
|
if !skipVerify {
|
||||||
|
transport = &http.Transport{
|
||||||
|
MaxIdleConns: transportPool.MaxIdleConns,
|
||||||
|
MaxIdleConnsPerHost: transportPool.MaxIdleConnsPerHost,
|
||||||
|
IdleConnTimeout: transportPool.IdleConnTimeout,
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &httputil.ReverseProxy{
|
return &httputil.ReverseProxy{
|
||||||
Director: director,
|
Director: director,
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
|
ModifyResponse: func(resp *http.Response) error {
|
||||||
|
// Preserve all response headers, including Set-Cookie
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
resp.Header[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle WebSocket upgrade
|
||||||
|
if strings.ToLower(resp.Header.Get("Upgrade")) == "websocket" {
|
||||||
|
return nil // No further modification needed for WebSocket
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compression if client supports it and response isn’t already compressed
|
||||||
|
if resp.Header.Get("Content-Encoding") == "" && strings.Contains(resp.Request.Header.Get("Accept-Encoding"), "gzip") {
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
errorLogger.Printf("Error reading response body for compression: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gw := gzip.NewWriter(&buf)
|
||||||
|
if _, err := gw.Write(body); err != nil {
|
||||||
|
errorLogger.Printf("Error compressing response: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
gw.Close()
|
||||||
|
|
||||||
|
resp.Body = io.NopCloser(&buf)
|
||||||
|
resp.Header.Set("Content-Encoding", "gzip")
|
||||||
|
resp.Header.Del("Content-Length") // Length changes after compression
|
||||||
|
trafficLogger.Printf("Compressed response for %s", resp.Request.URL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache static content with ETag support
|
||||||
|
if shouldCache(resp) {
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
errorLogger.Printf("Error reading response body for caching: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
|
||||||
|
etag := resp.Header.Get("ETag")
|
||||||
|
if etag == "" {
|
||||||
|
etag = generateETag(body) // Simple ETag generation if not provided
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheDuration := parseCacheControl(resp.Header.Get("Cache-Control"))
|
||||||
|
if cacheDuration == 0 {
|
||||||
|
cacheDuration = defaultCacheTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheMutex.Lock()
|
||||||
|
cache[resp.Request.URL.String()] = cachedResponse{
|
||||||
|
body: body,
|
||||||
|
headers: resp.Header.Clone(),
|
||||||
|
statusCode: resp.StatusCode,
|
||||||
|
cachedAt: time.Now(),
|
||||||
|
cacheDuration: cacheDuration,
|
||||||
|
etag: etag,
|
||||||
|
}
|
||||||
|
cacheMutex.Unlock()
|
||||||
|
trafficLogger.Printf("Cached response for %s [ETag: %s]", resp.Request.URL.String(), etag)
|
||||||
|
}
|
||||||
|
|
||||||
|
trafficLogger.Printf("Response: %s %d from %s", resp.Status, resp.StatusCode, target)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
log.Printf("Proxy error for %s: %v", r.Host, err)
|
log.Printf("Proxy error for %s: %v", r.Host, err)
|
||||||
|
errorLogger.Printf("Proxy error for %s: %v", r.Host, err)
|
||||||
http.Error(w, "Proxy error", http.StatusBadGateway)
|
http.Error(w, "Proxy error", http.StatusBadGateway)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// singleJoin ensures a single slash between path segments
|
func shouldCache(resp *http.Response) bool {
|
||||||
|
if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
return strings.HasPrefix(contentType, "text/") || strings.HasPrefix(contentType, "image/") ||
|
||||||
|
strings.HasPrefix(contentType, "application/javascript") || strings.HasPrefix(contentType, "application/json")
|
||||||
|
}
|
||||||
|
|
||||||
func singleJoin(prefix, suffix string) string {
|
func singleJoin(prefix, suffix string) string {
|
||||||
prefix = strings.TrimSuffix(prefix, "/")
|
prefix = strings.TrimSuffix(prefix, "/")
|
||||||
suffix = strings.TrimPrefix(suffix, "/")
|
suffix = strings.TrimPrefix(suffix, "/")
|
||||||
return prefix + "/" + suffix
|
return prefix + "/" + suffix
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateETag(body []byte) string {
|
||||||
|
return fmt.Sprintf(`"%x"`, md5.Sum(body)) // Simple ETag based on MD5 hash
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCacheControl(header string) time.Duration {
|
||||||
|
if header == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
parts := strings.Split(header, ",")
|
||||||
|
for _, part := range parts {
|
||||||
|
if strings.Contains(part, "max-age=") {
|
||||||
|
ageStr := strings.TrimPrefix(part, "max-age=")
|
||||||
|
ageStr = strings.TrimSpace(ageStr)
|
||||||
|
if age, err := strconv.Atoi(ageStr); err == nil {
|
||||||
|
return time.Duration(age) * time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLimiter(ip string) *rate.Limiter {
|
||||||
|
rateMutex.Lock()
|
||||||
|
defer rateMutex.Unlock()
|
||||||
|
|
||||||
|
if limiter, exists := rateLimiters[ip]; exists {
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
limiter := rate.NewLimiter(rateLimit, rateBurst)
|
||||||
|
rateLimiters[ip] = limiter
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
func handler(w http.ResponseWriter, r *http.Request) {
|
func handler(w http.ResponseWriter, r *http.Request) {
|
||||||
configMux.RLock()
|
configMux.RLock()
|
||||||
defer configMux.RUnlock()
|
defer configMux.RUnlock()
|
||||||
|
|
||||||
|
// Rate limiting based on client IP
|
||||||
|
clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")] // Strip port
|
||||||
|
limiter := getLimiter(clientIP)
|
||||||
|
if !limiter.Allow() {
|
||||||
|
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||||
|
trafficLogger.Printf("Rate limited client %s", clientIP)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
target, exists := config.Routes[r.Host]
|
target, exists := config.Routes[r.Host]
|
||||||
skipVerify := config.TrustTarget[r.Host]
|
skipVerify := config.TrustTarget[r.Host]
|
||||||
noHTTPSRedirect := config.NoHTTPSRedirect[r.Host]
|
noHTTPSRedirect := config.NoHTTPSRedirect[r.Host]
|
||||||
@@ -73,21 +253,80 @@ func handler(w http.ResponseWriter, r *http.Request) {
|
|||||||
noHTTPSRedirect = config.NoHTTPSRedirect["*"]
|
noHTTPSRedirect = config.NoHTTPSRedirect["*"]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if request is HTTP and target is HTTPS
|
|
||||||
isHTTPS := target[:5] == "https"
|
isHTTPS := target[:5] == "https"
|
||||||
isHTTPReq := r.TLS == nil // r.TLS is nil for HTTP, non-nil for HTTPS
|
isHTTPReq := r.TLS == nil
|
||||||
|
|
||||||
if isHTTPReq && isHTTPS && !noHTTPSRedirect {
|
if isHTTPReq && isHTTPS && !noHTTPSRedirect {
|
||||||
// Redirect to HTTPS version of the host
|
|
||||||
redirectURL := "https://" + r.Host + r.RequestURI
|
redirectURL := "https://" + r.Host + r.RequestURI
|
||||||
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
|
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cacheKey := r.URL.String()
|
||||||
|
cacheMutex.RLock()
|
||||||
|
if cached, ok := cache[cacheKey]; ok && time.Since(cached.cachedAt) < cached.cacheDuration {
|
||||||
|
if etag := r.Header.Get("If-None-Match"); etag != "" && etag == cached.etag {
|
||||||
|
w.WriteHeader(http.StatusNotModified)
|
||||||
|
trafficLogger.Printf("Served 304 Not Modified from cache: %s [ETag: %s]", cacheKey, cached.etag)
|
||||||
|
} else {
|
||||||
|
for k, v := range cached.headers {
|
||||||
|
w.Header()[k] = v
|
||||||
|
}
|
||||||
|
w.WriteHeader(cached.statusCode)
|
||||||
|
w.Write(cached.body)
|
||||||
|
trafficLogger.Printf("Served from cache: %s [ETag: %s]", cacheKey, cached.etag)
|
||||||
|
}
|
||||||
|
cacheMutex.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cacheMutex.RUnlock()
|
||||||
|
|
||||||
proxy := getReverseProxy(target, skipVerify)
|
proxy := getReverseProxy(target, skipVerify)
|
||||||
if proxy == nil {
|
if proxy == nil {
|
||||||
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
targetURL, err := url.Parse(target)
|
||||||
|
if err != nil {
|
||||||
|
errorLogger.Printf("Error parsing target URL %s: %v", target, err)
|
||||||
|
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket upgrade handling
|
||||||
|
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
|
||||||
|
hijacker, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "WebSocket upgrade not supported", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, _, err := hijacker.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
errorLogger.Printf("Failed to hijack connection for WebSocket: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
targetConn, err := transportPool.Dial("tcp", targetURL.Host)
|
||||||
|
if err != nil {
|
||||||
|
errorLogger.Printf("Failed to dial target for WebSocket: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer targetConn.Close()
|
||||||
|
|
||||||
|
// Forward request to target
|
||||||
|
if err := r.Write(targetConn); err != nil {
|
||||||
|
errorLogger.Printf("Failed to forward WebSocket request: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pipe connections
|
||||||
|
go io.Copy(conn, targetConn)
|
||||||
|
io.Copy(targetConn, conn)
|
||||||
|
trafficLogger.Printf("WebSocket connection established: %s -> %s", r.Host, target)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
proxy.ServeHTTP(w, r)
|
proxy.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user