Files
2025-03-02 17:14:05 +00:00

346 lines
10 KiB
Go

package main
import (
"context"
"crypto/tls"
"log"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"github.com/fsnotify/fsnotify"
"golangproxy/config"
"golangproxy/logger"
"golangproxy/proxy"
"golangproxy/server"
"golangproxy/ssl"
)
// Global variables for dynamic configuration and certificate updates.
// They are package-level so that main, the request handlers and the reload
// helpers (reloadConfig, reloadCert, updateCertWatchers) can share them.
var (
configPath = "config.yaml" // Path handed to config.LoadConfig and registered with the file watcher
routesMutex sync.RWMutex // Protects routes and defaultRoute
certMutex sync.RWMutex // Protects currentCert
currentConfig *config.Config // Current configuration; replaced wholesale by reloadConfig, not guarded by a mutex
currentCert *tls.Certificate // Current SSL certificate, handed out by the HTTPS server's GetCertificate callback
routes map[string]*proxy.Route // Host-specific routes, keyed by the request's Host value
defaultRoute *proxy.Route // Wildcard ("*") route returned by getRoute when no host-specific route matches
watcher *fsnotify.Watcher // File watcher instance, created in main and updated by updateCertWatchers
)
// main initializes and runs the reverse proxy application.
//
// It loads the configuration and the TLS key pair, builds the routing table,
// starts the auxiliary web server plus the HTTP and HTTPS listeners, watches
// the config, cert and key files for writes so they can be reloaded without a
// restart, and then blocks until SIGINT or SIGTERM triggers a shutdown that is
// bounded by a five-second timeout. Failures during startup are fatal, except
// that failing to watch the cert or key file is only logged.
func main() {
	// Upper bound on how long a client may take to send its request headers.
	// Only the header phase is limited; bodies and responses are not timed.
	const readHeaderTimeout = 10 * time.Second

	// Initialize logging to file and terminal.
	logger.InitLogger()
	log := logger.Logger // shadows the standard log package inside main

	// Load initial configuration.
	var err error
	currentConfig, err = config.LoadConfig(configPath)
	if err != nil {
		log.Fatalf("Error loading config: %v", err)
	}

	// Snapshot the listen addresses. The servers keep them for their whole
	// lifetime, and reloadConfig may replace currentConfig from the watcher
	// goroutine while the start-up goroutines below are still logging them.
	httpAddr, httpsAddr := currentConfig.ListenHTTP, currentConfig.ListenHTTPS

	// Ensure SSL certificate and key files exist.
	err = ssl.EnsureCertFiles(currentConfig.CertFile, currentConfig.KeyFile)
	if err != nil {
		log.Fatalf("Error ensuring cert files: %v", err)
	}

	// Load initial SSL certificate.
	cert, err := tls.LoadX509KeyPair(currentConfig.CertFile, currentConfig.KeyFile)
	if err != nil {
		log.Fatalf("Error loading cert: %v", err)
	}
	certMutex.Lock()
	currentCert = &cert
	certMutex.Unlock()

	// Initialize proxy routes from config.
	initializeRoutes(log)

	// Start the simple web server in a goroutine.
	go server.StartServer()

	// Configure HTTP server.
	httpServer := &http.Server{
		Addr: httpAddr,
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// getRoute acquires routesMutex itself. Taking the read lock here
			// as well would be a recursive read lock, which can deadlock as
			// soon as a reload is waiting for the write lock.
			route := getRoute(r.Host)
			if strings.HasPrefix(route.Target, "https://") && !route.NoHTTPSRedirect {
				// RequestURI keeps the path percent-encoded and appends the
				// query string, so the Location header stays a valid URL.
				httpsURL := "https://" + r.Host + r.URL.RequestURI()
				http.Redirect(w, r, httpsURL, http.StatusMovedPermanently)
				return
			}
			route.Handler.ServeHTTP(w, r) // Use Handler instead of Proxy
		}),
		ReadHeaderTimeout: readHeaderTimeout,
		ErrorLog:          logger.Logger, // Route server-level errors through the application logger
	}

	// Configure HTTPS server.
	httpsServer := &http.Server{
		Addr: httpsAddr,
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// See the HTTP handler: getRoute does its own locking.
			route := getRoute(r.Host)
			route.Handler.ServeHTTP(w, r) // Use Handler instead of Proxy
		}),
		TLSConfig: &tls.Config{
			// Looked up on every handshake so a reloaded certificate takes
			// effect without restarting the listener.
			GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
				certMutex.RLock()
				defer certMutex.RUnlock()
				return currentCert, nil
			},
		},
		ReadHeaderTimeout: readHeaderTimeout,
		ErrorLog:          logger.Logger, // Route server-level errors through the application logger
	}

	// Start servers in goroutines.
	go func() {
		log.Println("Starting HTTP server on", httpAddr)
		if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.Fatalf("HTTP server error: %v", err)
		}
	}()
	go func() {
		log.Println("Starting HTTPS server on", httpsAddr)
		// Empty file names: the certificate comes from TLSConfig.GetCertificate.
		if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
			log.Fatalf("HTTPS server error: %v", err)
		}
	}()

	// Initialize file watcher.
	watcher, err = fsnotify.NewWatcher()
	if err != nil {
		log.Fatalf("Error creating watcher: %v", err)
	}
	defer watcher.Close()

	// Watch initial config and cert files. Only the config watch is mandatory.
	err = watcher.Add(configPath)
	if err != nil {
		log.Fatalf("Error watching config file: %v", err)
	}
	err = watcher.Add(currentConfig.CertFile)
	if err != nil {
		log.Println("Error watching cert file:", err)
	}
	err = watcher.Add(currentConfig.KeyFile)
	if err != nil {
		log.Println("Error watching key file:", err)
	}

	// Handle file updates in a goroutine. It ends when the watcher's channels
	// are closed.
	go func() {
		for {
			select {
			case event, ok := <-watcher.Events:
				if !ok {
					return
				}
				if event.Op&fsnotify.Write == fsnotify.Write {
					switch event.Name {
					case configPath:
						log.Println("Config file changed, reloading...")
						reloadConfig(log)
					case currentConfig.CertFile, currentConfig.KeyFile:
						log.Println("Cert files changed, reloading cert...")
						reloadCert(log)
					}
				}
			case err, ok := <-watcher.Errors:
				if !ok {
					return
				}
				log.Println("Watcher error:", err)
			}
		}
	}()

	// Handle graceful shutdown.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
	log.Println("Shutting down...")

	// Both servers share the same five-second budget.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := httpServer.Shutdown(ctx); err != nil {
		log.Println("HTTP server shutdown error:", err)
	}
	if err := httpsServer.Shutdown(ctx); err != nil {
		log.Println("HTTPS server shutdown error:", err)
	}
}
// getRoute looks up the proxy route registered for host and falls back to the
// wildcard route when there is no host-specific entry. The lookup is done
// under the routes read lock.
func getRoute(host string) *proxy.Route {
	routesMutex.RLock()
	matched, found := routes[host]
	fallback := defaultRoute
	routesMutex.RUnlock()

	if !found {
		return fallback
	}
	return matched
}
// initializeRoutes rebuilds the host routing table and the wildcard route
// from currentConfig while holding the routes write lock. Each host entry
// takes its trust and no-redirect flags from its own setting, or from the
// "*" setting when it has none. A config without a "*" route is fatal.
func initializeRoutes(log *log.Logger) {
	routesMutex.Lock()
	defer routesMutex.Unlock()

	cfg := currentConfig

	// Host-specific entries; the wildcard is handled separately below.
	table := make(map[string]*proxy.Route)
	for host, target := range cfg.Routes {
		if host != "*" {
			hostRoute := proxy.CreateRoute(target, getConfigBool(cfg.TrustTarget, host))
			hostRoute.NoHTTPSRedirect = getConfigBool(cfg.NoHTTPSRedirect, host)
			table[host] = hostRoute
		}
	}
	routes = table

	wildcardTarget, found := cfg.Routes["*"]
	if !found {
		log.Fatal("Default route '*' not found in config")
	}
	wildcard := proxy.CreateRoute(wildcardTarget, cfg.TrustTarget["*"])
	wildcard.NoHTTPSRedirect = cfg.NoHTTPSRedirect["*"]
	defaultRoute = wildcard
}
// getConfigBool returns the flag stored for host in m. When host has no entry
// of its own, the value stored under "*" is used instead; if that is missing
// as well (or m is nil) the result is false.
func getConfigBool(m map[string]bool, host string) bool {
	flag, found := m[host]
	if !found {
		flag = m["*"]
	}
	return flag
}
// reloadConfig re-reads the config file and, if it is usable, makes it the
// current configuration: differences are logged, the routing table is
// rebuilt, and when the cert or key path changed the certificate is reloaded
// and the file watches are moved to the new paths.
//
// A config that fails to load or that lacks the "*" default route is rejected
// with a log message and the previous configuration stays in effect.
// Changes to the listen addresses are logged but the running servers are not
// rebound by this function.
func reloadConfig(log *log.Logger) {
	newConfig, err := config.LoadConfig(configPath)
	if err != nil {
		log.Println("Error reloading config:", err)
		return
	}

	// initializeRoutes terminates the process when the wildcard route is
	// missing. That is acceptable at startup, but a bad edit must not take
	// down a running proxy, so refuse such a config before anything is
	// swapped, the same way a config that fails to load is refused.
	if _, ok := newConfig.Routes["*"]; !ok {
		log.Println("Error reloading config: default route '*' not found, keeping previous config")
		return
	}

	// Log differences between old and new config.
	log.Println("Config file changed, reloading...")
	logConfigChanges(log, currentConfig, newConfig)

	// Store old cert file paths before updating config; updateCertWatchers
	// needs them to drop the stale watches.
	oldCertFile := currentConfig.CertFile
	oldKeyFile := currentConfig.KeyFile
	certChanged := newConfig.CertFile != oldCertFile || newConfig.KeyFile != oldKeyFile
	currentConfig = newConfig

	// Update routes.
	initializeRoutes(log)

	// Update certificates and watcher if paths changed.
	if certChanged {
		reloadCert(log)
		updateCertWatchers(log, oldCertFile, oldKeyFile)
	}
}
// logConfigChanges writes one log line for every setting that differs between
// oldConfig and newConfig: the listen addresses, the cert and key paths, and
// each added, removed or changed entry of the routes, trust_target and
// no_https_redirect maps. Nothing is logged for settings that are equal.
func logConfigChanges(log *log.Logger, oldConfig, newConfig *config.Config) {
	// Plain string settings, reported in a fixed order.
	scalars := []struct {
		format        string
		before, after string
	}{
		{"listen_http changed from %s to %s", oldConfig.ListenHTTP, newConfig.ListenHTTP},
		{"listen_https changed from %s to %s", oldConfig.ListenHTTPS, newConfig.ListenHTTPS},
		{"cert_file changed from %s to %s", oldConfig.CertFile, newConfig.CertFile},
		{"key_file changed from %s to %s", oldConfig.KeyFile, newConfig.KeyFile},
	}
	for _, s := range scalars {
		if s.before != s.after {
			log.Printf(s.format, s.before, s.after)
		}
	}

	// Routes: removals and changes first, then additions.
	for host, before := range oldConfig.Routes {
		after, kept := newConfig.Routes[host]
		switch {
		case !kept:
			log.Printf("Route %s removed (was %s)", host, before)
		case before != after:
			log.Printf("Route %s changed from %s to %s", host, before, after)
		}
	}
	for host, after := range newConfig.Routes {
		if _, existed := oldConfig.Routes[host]; !existed {
			log.Printf("Route %s added: %s", host, after)
		}
	}

	// TrustTarget flags.
	for host, before := range oldConfig.TrustTarget {
		after, kept := newConfig.TrustTarget[host]
		switch {
		case !kept:
			log.Printf("trust_target %s removed (was %t)", host, before)
		case before != after:
			log.Printf("trust_target %s changed from %t to %t", host, before, after)
		}
	}
	for host, after := range newConfig.TrustTarget {
		if _, existed := oldConfig.TrustTarget[host]; !existed {
			log.Printf("trust_target %s added: %t", host, after)
		}
	}

	// NoHTTPSRedirect flags.
	for host, before := range oldConfig.NoHTTPSRedirect {
		after, kept := newConfig.NoHTTPSRedirect[host]
		switch {
		case !kept:
			log.Printf("no_https_redirect %s removed (was %t)", host, before)
		case before != after:
			log.Printf("no_https_redirect %s changed from %t to %t", host, before, after)
		}
	}
	for host, after := range newConfig.NoHTTPSRedirect {
		if _, existed := oldConfig.NoHTTPSRedirect[host]; !existed {
			log.Printf("no_https_redirect %s added: %t", host, after)
		}
	}
}
// reloadCert loads the key pair named by the current configuration and, on
// success, installs it as the certificate served to TLS clients. If loading
// fails the error is logged and the previously installed certificate is kept.
func reloadCert(log *log.Logger) {
	certPath, keyPath := currentConfig.CertFile, currentConfig.KeyFile

	loaded, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		log.Println("Error reloading cert:", err)
		return
	}

	// Swap the pointer under the write lock so in-flight handshakes see
	// either the old or the new certificate.
	certMutex.Lock()
	defer certMutex.Unlock()
	currentCert = &loaded
}
// updateCertWatchers moves the file watches from the previous cert/key paths
// to the ones in the current configuration.
//
// oldCertFile and oldKeyFile are the paths that were configured before the
// reload. Paths that did not change are left alone. Failures to remove or add
// a watch are logged and otherwise ignored.
func updateCertWatchers(log *log.Logger, oldCertFile, oldKeyFile string) {
	newCertFile, newKeyFile := currentConfig.CertFile, currentConfig.KeyFile

	// Drop stale watches first, and never drop a path the new configuration
	// still uses. Removing after adding would unwatch a file again when the
	// cert and key paths are swapped.
	stillUsed := func(path string) bool {
		return path == newCertFile || path == newKeyFile
	}
	if !stillUsed(oldCertFile) {
		if err := watcher.Remove(oldCertFile); err != nil {
			log.Println("Error removing watch on old cert file:", err)
		}
	}
	// Skip the key path when it equals the cert path that was just handled.
	if oldKeyFile != oldCertFile && !stillUsed(oldKeyFile) {
		if err := watcher.Remove(oldKeyFile); err != nil {
			log.Println("Error removing watch on old key file:", err)
		}
	}

	if oldCertFile != newCertFile {
		if err := watcher.Add(newCertFile); err != nil {
			log.Println("Error watching new cert file:", err)
		}
	}
	if oldKeyFile != newKeyFile {
		if err := watcher.Add(newKeyFile); err != nil {
			log.Println("Error watching new key file:", err)
		}
	}
}