updated app build
This commit is contained in:
32
.gitignore
vendored
Normal file
32
.gitignore
vendored
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# If you prefer the allow list template instead of the deny list, see community template:
|
||||||
|
# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore
|
||||||
|
#
|
||||||
|
# Binaries for programs and plugins
|
||||||
|
*.exe
|
||||||
|
*.exe~
|
||||||
|
*.dll
|
||||||
|
*.so
|
||||||
|
*.dylib
|
||||||
|
|
||||||
|
# Test binary, built with `go test -c`
|
||||||
|
*.test
|
||||||
|
|
||||||
|
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||||
|
*.out
|
||||||
|
|
||||||
|
# Dependency directories (remove the comment below to include it)
|
||||||
|
# vendor/
|
||||||
|
|
||||||
|
# Go workspace file
|
||||||
|
go.work
|
||||||
|
go.work.sum
|
||||||
|
|
||||||
|
# env file
|
||||||
|
.env
|
||||||
|
|
||||||
|
*.pem
|
||||||
|
.bcp/
|
||||||
|
bcp/
|
||||||
|
logs/
|
||||||
|
www/
|
||||||
|
*.bcp
|
||||||
20
README.md
20
README.md
@@ -25,17 +25,19 @@ no_https_redirect:
|
|||||||
'*': false
|
'*': false
|
||||||
main.example.com: true
|
main.example.com: true
|
||||||
```
|
```
|
||||||
### setup project
|
### setup project - powershell:
|
||||||
go mod init proxy
|
go mod init golangproxy ; go mod tidy
|
||||||
|
### setup project - cmd,bash:
|
||||||
|
go mod init golangproxy && go mod tidy
|
||||||
|
|
||||||
### Running Proxy app without compiling.
|
### Running Proxy app without compiling.
|
||||||
go run main.go config.go certificate.go proxy.go utils.go
|
go run main.go
|
||||||
|
|
||||||
### Building app:
|
### Building app:
|
||||||
go build -o proxy
|
go build -o build/golangproxy.exe
|
||||||
|
go build -ldflags="-H=windowsgui" -o build/golangproxy.exe
|
||||||
|
|
||||||
## Thins to do:
|
### Known issue.
|
||||||
- improve performance
|
- currently there is logic to proxy ip address target differently then hostname target.
|
||||||
- add logging to a file - failures, refreshed config or certificate, logging for proxied traffic
|
- this is because if same logic is applied for hostname as for IP target the IP target may experience issue with sessions ( I am not expert, so I do not know what causing it, but some sessions been dicsonnected, part of the website was not acting as normal)
|
||||||
- solve issue where some https proxy queries are rejected by other side (could be just issue with my test target)
|
- if logic applied for IP target is applied for hostname Target the website may not load or will be 404
|
||||||
- solve issue when the app randomly starts refershing cert without certificate being changed
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"math/big"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func generateSelfSignedCert() error {
|
|
||||||
if err := os.MkdirAll(certDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to create certificate directory %s: %v", certDir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to generate private key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
notBefore := time.Now()
|
|
||||||
notAfter := notBefore.Add(365 * 24 * time.Hour)
|
|
||||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to generate serial number: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
template := x509.Certificate{
|
|
||||||
SerialNumber: serialNumber,
|
|
||||||
Subject: pkix.Name{
|
|
||||||
Organization: []string{"Proxy Self-Signed"},
|
|
||||||
CommonName: "localhost",
|
|
||||||
},
|
|
||||||
NotBefore: notBefore,
|
|
||||||
NotAfter: notAfter,
|
|
||||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
DNSNames: []string{"localhost", "*.example.com"},
|
|
||||||
}
|
|
||||||
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create certificate: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
certOut, err := os.Create(certPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to open %s for writing: %v", certPath, err)
|
|
||||||
}
|
|
||||||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
|
||||||
certOut.Close()
|
|
||||||
return fmt.Errorf("failed to encode certificate: %v", err)
|
|
||||||
}
|
|
||||||
certOut.Close()
|
|
||||||
|
|
||||||
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
|
|
||||||
}
|
|
||||||
if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
|
|
||||||
keyOut.Close()
|
|
||||||
return fmt.Errorf("failed to encode private key: %v", err)
|
|
||||||
}
|
|
||||||
keyOut.Close()
|
|
||||||
|
|
||||||
log.Printf("Generated self-signed certificate in %s", certDir)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadCertificate() error {
|
|
||||||
certFile, err := os.ReadFile(certPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read certificate %s: %v", certPath, err)
|
|
||||||
}
|
|
||||||
keyFile, err := os.ReadFile(keyPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read key %s: %v", keyPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
newCert, err := tls.X509KeyPair(certFile, keyFile)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse certificate: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
configMux.Lock()
|
|
||||||
cert = &newCert
|
|
||||||
configMux.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func monitorCertificates() {
|
|
||||||
var lastModTime time.Time
|
|
||||||
for {
|
|
||||||
certInfo, err := os.Stat(certPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error checking certificate: %v", err)
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
keyInfo, err := os.Stat(keyPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error checking key: %v", err)
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if certInfo.ModTime() != lastModTime || keyInfo.ModTime() != lastModTime {
|
|
||||||
if err := loadCertificate(); err != nil {
|
|
||||||
log.Printf("Error reloading certificate: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Println("Certificate reloaded successfully")
|
|
||||||
lastModTime = certInfo.ModTime()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
154
bcp/v1/config.go
154
bcp/v1/config.go
@@ -1,154 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gopkg.in/yaml.v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
ListenHTTP string `yaml:"listen_http"`
|
|
||||||
ListenHTTPS string `yaml:"listen_https"`
|
|
||||||
CertDir string `yaml:"cert_dir"`
|
|
||||||
CertFile string `yaml:"cert_file"`
|
|
||||||
KeyFile string `yaml:"key_file"`
|
|
||||||
Routes map[string]string `yaml:"routes"`
|
|
||||||
TrustTarget map[string]bool `yaml:"trust_target"`
|
|
||||||
NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"` // New field
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadConfig() (Config, error) {
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("failed to read config %s: %v", configPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cfg Config
|
|
||||||
err = yaml.Unmarshal(data, &cfg)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("failed to unmarshal config: %v", err)
|
|
||||||
}
|
|
||||||
return cfg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateDefaultConfig() Config {
|
|
||||||
return Config{
|
|
||||||
ListenHTTP: ":80",
|
|
||||||
ListenHTTPS: ":443",
|
|
||||||
CertDir: "certificates", // default current directory creates certificates folder
|
|
||||||
CertFile: "certificate.pem",
|
|
||||||
KeyFile: "key.pem",
|
|
||||||
Routes: map[string]string{
|
|
||||||
"*": "http://127.0.0.1:80",
|
|
||||||
"main.example.com": "http://127.0.0.1:80",
|
|
||||||
},
|
|
||||||
TrustTarget: map[string]bool{
|
|
||||||
"*": true, // default trust all certificates
|
|
||||||
"main.example.com": false, // use only trusted certificate
|
|
||||||
},
|
|
||||||
NoHTTPSRedirect: map[string]bool{
|
|
||||||
"*": false, // Default: redirect to HTTPS
|
|
||||||
"main.example.com": true, // set to not redirect HTTP to HTTPS
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func saveConfig(cfg Config) error {
|
|
||||||
if err := os.MkdirAll(baseDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to create base directory %s: %v", baseDir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := yaml.Marshal(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = os.WriteFile(configPath, data, 0644)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write config to %s: %v", configPath, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func monitorConfig() {
|
|
||||||
var lastModTime time.Time
|
|
||||||
for {
|
|
||||||
configInfo, err := os.Stat(configPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error checking config file: %v", err)
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if configInfo.ModTime() != lastModTime {
|
|
||||||
newConfig, err := loadConfig()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error reloading config: %v", err)
|
|
||||||
} else {
|
|
||||||
configMux.Lock()
|
|
||||||
if newConfig.ListenHTTP != config.ListenHTTP {
|
|
||||||
config.ListenHTTP = newConfig.ListenHTTP
|
|
||||||
log.Printf("Updated listen_http to %s", config.ListenHTTP)
|
|
||||||
}
|
|
||||||
if newConfig.ListenHTTPS != config.ListenHTTPS {
|
|
||||||
config.ListenHTTPS = newConfig.ListenHTTPS
|
|
||||||
log.Printf("Updated listen_https to %s", config.ListenHTTPS)
|
|
||||||
}
|
|
||||||
if newConfig.CertDir != config.CertDir || newConfig.CertFile != config.CertFile || newConfig.KeyFile != config.KeyFile {
|
|
||||||
config.CertDir = newConfig.CertDir
|
|
||||||
config.CertFile = newConfig.CertFile
|
|
||||||
config.KeyFile = newConfig.KeyFile
|
|
||||||
updatePaths()
|
|
||||||
if err := loadCertificate(); err != nil {
|
|
||||||
log.Printf("Error reloading certificate after path change: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Println("Updated certificate paths and reloaded certificate")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k, v := range newConfig.Routes {
|
|
||||||
if oldV, exists := config.Routes[k]; !exists || oldV != v {
|
|
||||||
config.Routes[k] = v
|
|
||||||
log.Printf("Updated route %s to %s", k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k := range config.Routes {
|
|
||||||
if _, exists := newConfig.Routes[k]; !exists {
|
|
||||||
delete(config.Routes, k)
|
|
||||||
log.Printf("Removed route %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k, v := range newConfig.TrustTarget {
|
|
||||||
if oldV, exists := config.TrustTarget[k]; !exists || oldV != v {
|
|
||||||
config.TrustTarget[k] = v
|
|
||||||
log.Printf("Updated trust_target %s to %v", k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k := range config.TrustTarget {
|
|
||||||
if _, exists := newConfig.TrustTarget[k]; !exists {
|
|
||||||
delete(config.TrustTarget, k)
|
|
||||||
log.Printf("Removed trust_target %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k, v := range newConfig.NoHTTPSRedirect {
|
|
||||||
if oldV, exists := config.NoHTTPSRedirect[k]; !exists || oldV != v {
|
|
||||||
config.NoHTTPSRedirect[k] = v
|
|
||||||
log.Printf("Updated no_https_redirect %s to %v", k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k := range config.NoHTTPSRedirect {
|
|
||||||
if _, exists := newConfig.NoHTTPSRedirect[k]; !exists {
|
|
||||||
delete(config.NoHTTPSRedirect, k)
|
|
||||||
log.Printf("Removed no_https_redirect %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
configMux.Unlock()
|
|
||||||
log.Println("Config reloaded successfully")
|
|
||||||
lastModTime = configInfo.ModTime()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
105
bcp/v1/main.go
105
bcp/v1/main.go
@@ -1,105 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
// go run main.go config.go certificate.go proxy.go utils.go
|
|
||||||
// go build -o proxy
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Shared global variables (declared only here)
|
|
||||||
var (
|
|
||||||
config Config
|
|
||||||
configMux sync.RWMutex
|
|
||||||
cert *tls.Certificate
|
|
||||||
baseDir string
|
|
||||||
certDir string
|
|
||||||
certPath string
|
|
||||||
keyPath string
|
|
||||||
configPath string
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
if runtime.GOOS == "windows" && len(os.Args) > 0 && filepath.Ext(os.Args[0]) == ".go" {
|
|
||||||
var err error
|
|
||||||
baseDir, err = os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to get working directory: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
exePath, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to get executable path: %v", err)
|
|
||||||
}
|
|
||||||
baseDir = filepath.Dir(exePath)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
configPath = filepath.Join(baseDir, "config.yaml")
|
|
||||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
|
||||||
config = generateDefaultConfig()
|
|
||||||
if err := saveConfig(config); err != nil {
|
|
||||||
log.Fatalf("Failed to save default config: %v", err)
|
|
||||||
}
|
|
||||||
log.Println("Generated default config file")
|
|
||||||
} else {
|
|
||||||
cfg, err := loadConfig()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
config = cfg
|
|
||||||
}
|
|
||||||
|
|
||||||
updatePaths()
|
|
||||||
|
|
||||||
_, certErr := os.Stat(certPath)
|
|
||||||
_, keyErr := os.Stat(keyPath)
|
|
||||||
if os.IsNotExist(certErr) || os.IsNotExist(keyErr) {
|
|
||||||
if err := generateSelfSignedCert(); err != nil {
|
|
||||||
log.Fatalf("Failed to generate self-signed certificate: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := loadCertificate(); err != nil {
|
|
||||||
log.Fatalf("Failed to load certificate: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go monitorCertificates()
|
|
||||||
go monitorConfig()
|
|
||||||
|
|
||||||
httpServer := &http.Server{
|
|
||||||
Addr: config.ListenHTTP,
|
|
||||||
Handler: http.HandlerFunc(handler),
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
configMux.RLock()
|
|
||||||
defer configMux.RUnlock()
|
|
||||||
return cert, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
httpsServer := &http.Server{
|
|
||||||
Addr: config.ListenHTTPS,
|
|
||||||
Handler: http.HandlerFunc(handler),
|
|
||||||
TLSConfig: tlsConfig,
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
log.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
|
||||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
||||||
log.Fatalf("HTTP server error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
log.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
|
||||||
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
|
||||||
log.Fatalf("HTTPS server error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,93 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
|
||||||
targetURL, err := url.Parse(target)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error parsing target URL %s: %v", target, err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
director := func(req *http.Request) {
|
|
||||||
req.URL.Scheme = targetURL.Scheme
|
|
||||||
req.URL.Host = targetURL.Host
|
|
||||||
req.Host = targetURL.Host // Set Host header to match target
|
|
||||||
|
|
||||||
// Preserve original headers
|
|
||||||
for k, v := range req.Header {
|
|
||||||
if k != "Host" { // Host is set above
|
|
||||||
req.Header[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure the full path is preserved
|
|
||||||
if targetURL.Path != "" {
|
|
||||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, "/") // Avoid double slashes
|
|
||||||
req.URL.Path = singleJoin(targetURL.Path, req.URL.Path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
transport := &http.Transport{
|
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: skipVerify},
|
|
||||||
}
|
|
||||||
|
|
||||||
return &httputil.ReverseProxy{
|
|
||||||
Director: director,
|
|
||||||
Transport: transport,
|
|
||||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
log.Printf("Proxy error for %s: %v", r.Host, err)
|
|
||||||
http.Error(w, "Proxy error", http.StatusBadGateway)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// singleJoin ensures a single slash between path segments
|
|
||||||
func singleJoin(prefix, suffix string) string {
|
|
||||||
prefix = strings.TrimSuffix(prefix, "/")
|
|
||||||
suffix = strings.TrimPrefix(suffix, "/")
|
|
||||||
return prefix + "/" + suffix
|
|
||||||
}
|
|
||||||
|
|
||||||
func handler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
configMux.RLock()
|
|
||||||
defer configMux.RUnlock()
|
|
||||||
|
|
||||||
target, exists := config.Routes[r.Host]
|
|
||||||
skipVerify := config.TrustTarget[r.Host]
|
|
||||||
noHTTPSRedirect := config.NoHTTPSRedirect[r.Host]
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
if target, exists = config.Routes["*"]; !exists {
|
|
||||||
http.Error(w, "Host not configured", http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
skipVerify = config.TrustTarget["*"]
|
|
||||||
noHTTPSRedirect = config.NoHTTPSRedirect["*"]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if request is HTTP and target is HTTPS
|
|
||||||
isHTTPS := target[:5] == "https"
|
|
||||||
isHTTPReq := r.TLS == nil // r.TLS is nil for HTTP, non-nil for HTTPS
|
|
||||||
|
|
||||||
if isHTTPReq && isHTTPS && !noHTTPSRedirect {
|
|
||||||
// Redirect to HTTPS version of the host
|
|
||||||
redirectURL := "https://" + r.Host + r.RequestURI
|
|
||||||
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy := getReverseProxy(target, skipVerify)
|
|
||||||
if proxy == nil {
|
|
||||||
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
proxy.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import "path/filepath"
|
|
||||||
|
|
||||||
func updatePaths() {
|
|
||||||
certDir = filepath.Join(baseDir, config.CertDir)
|
|
||||||
certPath = filepath.Join(certDir, config.CertFile)
|
|
||||||
keyPath = filepath.Join(certDir, config.KeyFile)
|
|
||||||
configPath = filepath.Join(baseDir, "config.yaml")
|
|
||||||
}
|
|
||||||
@@ -1,124 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
|
||||||
"math/big"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func generateSelfSignedCert() error {
|
|
||||||
if err := os.MkdirAll(certDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to create certificate directory %s: %v", certDir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to generate private key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
notBefore := time.Now()
|
|
||||||
notAfter := notBefore.Add(365 * 24 * time.Hour)
|
|
||||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to generate serial number: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
template := x509.Certificate{
|
|
||||||
SerialNumber: serialNumber,
|
|
||||||
Subject: pkix.Name{
|
|
||||||
Organization: []string{"Proxy Self-Signed"},
|
|
||||||
CommonName: "localhost",
|
|
||||||
},
|
|
||||||
NotBefore: notBefore,
|
|
||||||
NotAfter: notAfter,
|
|
||||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
DNSNames: []string{"localhost", "*.example.com"},
|
|
||||||
}
|
|
||||||
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create certificate: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
certOut, err := os.Create(certPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to open %s for writing: %v", certPath, err)
|
|
||||||
}
|
|
||||||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
|
||||||
certOut.Close()
|
|
||||||
return fmt.Errorf("failed to encode certificate: %v", err)
|
|
||||||
}
|
|
||||||
certOut.Close()
|
|
||||||
|
|
||||||
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_WRONLY, 0600)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
|
|
||||||
}
|
|
||||||
if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
|
|
||||||
keyOut.Close()
|
|
||||||
return fmt.Errorf("failed to encode private key: %v", err)
|
|
||||||
}
|
|
||||||
keyOut.Close()
|
|
||||||
|
|
||||||
refreshLogger.Printf("Generated self-signed certificate in %s", certDir)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadCertificate() error {
|
|
||||||
certFile, err := os.ReadFile(certPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read certificate %s: %v", certPath, err)
|
|
||||||
}
|
|
||||||
keyFile, err := os.ReadFile(keyPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read key %s: %v", keyPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
newCert, err := tls.X509KeyPair(certFile, keyFile)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse certificate: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
configMux.Lock()
|
|
||||||
cert = &newCert
|
|
||||||
configMux.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func monitorCertificates() {
|
|
||||||
var lastModTime time.Time
|
|
||||||
for {
|
|
||||||
certInfo, err := os.Stat(certPath)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error checking certificate: %v", err)
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
keyInfo, err := os.Stat(keyPath)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error checking key: %v", err)
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if certInfo.ModTime() != lastModTime || keyInfo.ModTime() != lastModTime {
|
|
||||||
if err := loadCertificate(); err != nil {
|
|
||||||
errorLogger.Printf("Error reloading certificate: %v", err)
|
|
||||||
} else {
|
|
||||||
refreshLogger.Println("Certificate reloaded successfully")
|
|
||||||
lastModTime = certInfo.ModTime()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
153
bcp/v2/config.go
153
bcp/v2/config.go
@@ -1,153 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gopkg.in/yaml.v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
ListenHTTP string `yaml:"listen_http"`
|
|
||||||
ListenHTTPS string `yaml:"listen_https"`
|
|
||||||
CertDir string `yaml:"cert_dir"`
|
|
||||||
CertFile string `yaml:"cert_file"`
|
|
||||||
KeyFile string `yaml:"key_file"`
|
|
||||||
Routes map[string]string `yaml:"routes"`
|
|
||||||
TrustTarget map[string]bool `yaml:"trust_target"`
|
|
||||||
NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadConfig() (Config, error) {
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("failed to read config %s: %v", configPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cfg Config
|
|
||||||
err = yaml.Unmarshal(data, &cfg)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("failed to unmarshal config: %v", err)
|
|
||||||
}
|
|
||||||
return cfg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateDefaultConfig() Config {
|
|
||||||
return Config{
|
|
||||||
ListenHTTP: ":80",
|
|
||||||
ListenHTTPS: ":443",
|
|
||||||
CertDir: "certificates",
|
|
||||||
CertFile: "certificate.pem",
|
|
||||||
KeyFile: "key.pem",
|
|
||||||
Routes: map[string]string{
|
|
||||||
"*": "https://127.0.0.1:3000",
|
|
||||||
"main.example.com": "https://10.100.111.254:4444",
|
|
||||||
},
|
|
||||||
TrustTarget: map[string]bool{
|
|
||||||
"*": true,
|
|
||||||
"main.example.com": true,
|
|
||||||
},
|
|
||||||
NoHTTPSRedirect: map[string]bool{
|
|
||||||
"*": false,
|
|
||||||
"main.example.com": false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func saveConfig(cfg Config) error {
|
|
||||||
if err := os.MkdirAll(baseDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to create base directory %s: %v", baseDir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := yaml.Marshal(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = os.WriteFile(configPath, data, 0644)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write config to %s: %v", configPath, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func monitorConfig() {
|
|
||||||
var lastModTime time.Time
|
|
||||||
for {
|
|
||||||
configInfo, err := os.Stat(configPath)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error checking config file: %v", err)
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if configInfo.ModTime() != lastModTime {
|
|
||||||
newConfig, err := loadConfig()
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error reloading config: %v", err)
|
|
||||||
} else {
|
|
||||||
configMux.Lock()
|
|
||||||
if newConfig.ListenHTTP != config.ListenHTTP {
|
|
||||||
config.ListenHTTP = newConfig.ListenHTTP
|
|
||||||
refreshLogger.Printf("Updated listen_http to %s", config.ListenHTTP)
|
|
||||||
}
|
|
||||||
if newConfig.ListenHTTPS != config.ListenHTTPS {
|
|
||||||
config.ListenHTTPS = newConfig.ListenHTTPS
|
|
||||||
refreshLogger.Printf("Updated listen_https to %s", config.ListenHTTPS)
|
|
||||||
}
|
|
||||||
if newConfig.CertDir != config.CertDir || newConfig.CertFile != config.CertFile || newConfig.KeyFile != config.KeyFile {
|
|
||||||
config.CertDir = newConfig.CertDir
|
|
||||||
config.CertFile = newConfig.CertFile
|
|
||||||
config.KeyFile = newConfig.KeyFile
|
|
||||||
updatePaths()
|
|
||||||
if err := loadCertificate(); err != nil {
|
|
||||||
errorLogger.Printf("Error reloading certificate after path change: %v", err)
|
|
||||||
} else {
|
|
||||||
refreshLogger.Println("Updated certificate paths and reloaded certificate")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k, v := range newConfig.Routes {
|
|
||||||
if oldV, exists := config.Routes[k]; !exists || oldV != v {
|
|
||||||
config.Routes[k] = v
|
|
||||||
refreshLogger.Printf("Updated route %s to %s", k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k := range config.Routes {
|
|
||||||
if _, exists := newConfig.Routes[k]; !exists {
|
|
||||||
delete(config.Routes, k)
|
|
||||||
refreshLogger.Printf("Removed route %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k, v := range newConfig.TrustTarget {
|
|
||||||
if oldV, exists := config.TrustTarget[k]; !exists || oldV != v {
|
|
||||||
config.TrustTarget[k] = v
|
|
||||||
refreshLogger.Printf("Updated trust_target %s to %v", k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k := range config.TrustTarget {
|
|
||||||
if _, exists := newConfig.TrustTarget[k]; !exists {
|
|
||||||
delete(config.TrustTarget, k)
|
|
||||||
refreshLogger.Printf("Removed trust_target %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k, v := range newConfig.NoHTTPSRedirect {
|
|
||||||
if oldV, exists := config.NoHTTPSRedirect[k]; !exists || oldV != v {
|
|
||||||
config.NoHTTPSRedirect[k] = v
|
|
||||||
refreshLogger.Printf("Updated no_https_redirect %s to %v", k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k := range config.NoHTTPSRedirect {
|
|
||||||
if _, exists := newConfig.NoHTTPSRedirect[k]; !exists {
|
|
||||||
delete(config.NoHTTPSRedirect, k)
|
|
||||||
refreshLogger.Printf("Removed no_https_redirect %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
configMux.Unlock()
|
|
||||||
refreshLogger.Println("Config reloaded successfully")
|
|
||||||
lastModTime = configInfo.ModTime()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
197
bcp/v2/main.go
197
bcp/v2/main.go
@@ -1,197 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
config Config
|
|
||||||
configMux sync.RWMutex
|
|
||||||
cert *tls.Certificate
|
|
||||||
baseDir string
|
|
||||||
certDir string
|
|
||||||
certPath string
|
|
||||||
keyPath string
|
|
||||||
configPath string
|
|
||||||
|
|
||||||
// Loggers
|
|
||||||
errorLogger *log.Logger
|
|
||||||
refreshLogger *log.Logger
|
|
||||||
trafficLogger *log.Logger
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// Get the absolute path of the running executable (or main.go in `go run`)
|
|
||||||
exePath, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to get executable path: %v", err)
|
|
||||||
}
|
|
||||||
baseDir = filepath.Dir(exePath)
|
|
||||||
/*
|
|
||||||
if runtime.GOOS == "windows" && len(os.Args) > 0 && filepath.Ext(os.Args[0]) == ".go" {
|
|
||||||
var err error
|
|
||||||
baseDir, err = os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to get working directory: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
exePath, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to get executable path: %v", err)
|
|
||||||
}
|
|
||||||
baseDir = filepath.Dir(exePath)
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
// Setup logging
|
|
||||||
setupLogging()
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupLogging() {
|
|
||||||
logsDir := filepath.Join(baseDir, "logs")
|
|
||||||
if err := os.MkdirAll(logsDir, 0755); err != nil {
|
|
||||||
log.Fatalf("Failed to create logs directory: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error log
|
|
||||||
errorFile, err := os.OpenFile(filepath.Join(logsDir, "errors.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to open errors.log: %v", err)
|
|
||||||
}
|
|
||||||
errorLogger = log.New(io.MultiWriter(os.Stdout, errorFile), "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile)
|
|
||||||
|
|
||||||
// Refresh log
|
|
||||||
refreshFile, err := os.OpenFile(filepath.Join(logsDir, "refresh.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to open refresh.log: %v", err)
|
|
||||||
}
|
|
||||||
refreshLogger = log.New(io.MultiWriter(os.Stdout, refreshFile), "REFRESH: ", log.Ldate|log.Ltime|log.Lshortfile)
|
|
||||||
|
|
||||||
// Traffic log (daily rotation)
|
|
||||||
go manageTrafficLogs(logsDir)
|
|
||||||
}
|
|
||||||
|
|
||||||
func manageTrafficLogs(logsDir string) {
|
|
||||||
for {
|
|
||||||
dateStr := time.Now().Format("2006-01-02")
|
|
||||||
trafficFile, err := os.OpenFile(filepath.Join(logsDir, "traffic-"+dateStr+".log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Failed to open traffic log: %v", err)
|
|
||||||
time.Sleep(1 * time.Minute)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
trafficLogger = log.New(io.MultiWriter(os.Stdout, trafficFile), "TRAFFIC: ", log.Ldate|log.Ltime)
|
|
||||||
|
|
||||||
// Cleanup old logs
|
|
||||||
cleanupOldLogs(logsDir)
|
|
||||||
|
|
||||||
// Wait until next day
|
|
||||||
nextDay := time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour)
|
|
||||||
time.Sleep(time.Until(nextDay))
|
|
||||||
trafficFile.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanupOldLogs(logsDir string) {
|
|
||||||
files, err := os.ReadDir(logsDir)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Failed to read logs directory: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cutoff := time.Now().AddDate(0, 0, -7) // 7 days ago
|
|
||||||
for _, file := range files {
|
|
||||||
if strings.HasPrefix(file.Name(), "traffic-") && file.Name() != "traffic-"+time.Now().Format("2006-01-02")+".log" {
|
|
||||||
info, err := file.Info()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if info.ModTime().Before(cutoff) {
|
|
||||||
if err := os.Remove(filepath.Join(logsDir, file.Name())); err != nil {
|
|
||||||
errorLogger.Printf("Failed to remove old log %s: %v", file.Name(), err)
|
|
||||||
} else {
|
|
||||||
refreshLogger.Printf("Removed old traffic log: %s", file.Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
configPath = filepath.Join(baseDir, "config.yaml")
|
|
||||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
|
||||||
config = generateDefaultConfig()
|
|
||||||
if err := saveConfig(config); err != nil {
|
|
||||||
errorLogger.Fatalf("Failed to save default config: %v", err)
|
|
||||||
}
|
|
||||||
refreshLogger.Println("Generated default config file")
|
|
||||||
} else {
|
|
||||||
cfg, err := loadConfig()
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
config = cfg
|
|
||||||
}
|
|
||||||
|
|
||||||
updatePaths()
|
|
||||||
|
|
||||||
_, certErr := os.Stat(certPath)
|
|
||||||
_, keyErr := os.Stat(keyPath)
|
|
||||||
if os.IsNotExist(certErr) || os.IsNotExist(keyErr) {
|
|
||||||
if err := generateSelfSignedCert(); err != nil {
|
|
||||||
errorLogger.Fatalf("Failed to generate self-signed certificate: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := loadCertificate(); err != nil {
|
|
||||||
errorLogger.Fatalf("Failed to load certificate: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go monitorCertificates()
|
|
||||||
go monitorConfig()
|
|
||||||
|
|
||||||
httpServer := &http.Server{
|
|
||||||
Addr: config.ListenHTTP,
|
|
||||||
Handler: http.HandlerFunc(handler),
|
|
||||||
MaxHeaderBytes: 1 << 20,
|
|
||||||
ReadTimeout: 10 * time.Second,
|
|
||||||
WriteTimeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
configMux.RLock()
|
|
||||||
defer configMux.RUnlock()
|
|
||||||
return cert, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
httpsServer := &http.Server{
|
|
||||||
Addr: config.ListenHTTPS,
|
|
||||||
Handler: http.HandlerFunc(handler),
|
|
||||||
TLSConfig: tlsConfig,
|
|
||||||
MaxHeaderBytes: 1 << 20,
|
|
||||||
ReadTimeout: 10 * time.Second,
|
|
||||||
WriteTimeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
log.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
|
||||||
trafficLogger.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
|
||||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
||||||
errorLogger.Fatalf("HTTP server error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
log.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
|
||||||
trafficLogger.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
|
||||||
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
|
||||||
errorLogger.Fatalf("HTTPS server error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
332
bcp/v2/proxy.go
332
bcp/v2/proxy.go
@@ -1,332 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"compress/gzip"
|
|
||||||
"crypto/md5"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/time/rate"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
transportPool = &http.Transport{
|
|
||||||
MaxIdleConns: 100,
|
|
||||||
MaxIdleConnsPerHost: 10,
|
|
||||||
IdleConnTimeout: 90 * time.Second,
|
|
||||||
DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext,
|
|
||||||
ResponseHeaderTimeout: 60 * time.Second,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
|
||||||
}
|
|
||||||
cache = make(map[string]cachedResponse)
|
|
||||||
defaultCacheTTL = 5 * time.Minute
|
|
||||||
cacheMutex sync.RWMutex
|
|
||||||
|
|
||||||
// Rate limiter per client IP
|
|
||||||
rateLimiters = make(map[string]*rate.Limiter)
|
|
||||||
rateMutex sync.RWMutex
|
|
||||||
rateLimit = rate.Limit(10) // 10 requests per second per client
|
|
||||||
rateBurst = 20 // Allow burst of 20 requests
|
|
||||||
)
|
|
||||||
|
|
||||||
type cachedResponse struct {
|
|
||||||
body []byte
|
|
||||||
headers http.Header
|
|
||||||
statusCode int
|
|
||||||
cachedAt time.Time
|
|
||||||
cacheDuration time.Duration
|
|
||||||
etag string
|
|
||||||
}
|
|
||||||
|
|
||||||
func getReverseProxy(target string, skipVerify bool) *httputil.ReverseProxy {
|
|
||||||
targetURL, err := url.Parse(target)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error parsing target URL %s: %v", target, err)
|
|
||||||
errorLogger.Printf("Error parsing target URL %s: %v", target, err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
director := func(req *http.Request) {
|
|
||||||
req.URL.Scheme = targetURL.Scheme
|
|
||||||
req.URL.Host = targetURL.Host
|
|
||||||
// Preserve client's Host header for session continuity
|
|
||||||
if req.Header.Get("Host") != "" {
|
|
||||||
req.Host = req.Header.Get("Host")
|
|
||||||
} else {
|
|
||||||
req.Host = targetURL.Host
|
|
||||||
}
|
|
||||||
|
|
||||||
// Preserve all request headers, including cookies
|
|
||||||
for k, v := range req.Header {
|
|
||||||
req.Header[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// Preserve query parameters and path
|
|
||||||
req.URL.RawQuery = req.URL.RawQuery
|
|
||||||
if targetURL.Path != "" {
|
|
||||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, "/")
|
|
||||||
req.URL.Path = singleJoin(targetURL.Path, req.URL.Path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// WebSocket support: pass upgrade headers
|
|
||||||
if strings.ToLower(req.Header.Get("Upgrade")) == "websocket" {
|
|
||||||
req.Header.Set("Connection", "Upgrade")
|
|
||||||
}
|
|
||||||
|
|
||||||
trafficLogger.Printf("Request: %s %s -> %s [Host: %s]", req.Method, req.URL.String(), target, req.Host)
|
|
||||||
}
|
|
||||||
|
|
||||||
transport := transportPool
|
|
||||||
if !skipVerify {
|
|
||||||
transport = &http.Transport{
|
|
||||||
MaxIdleConns: transportPool.MaxIdleConns,
|
|
||||||
MaxIdleConnsPerHost: transportPool.MaxIdleConnsPerHost,
|
|
||||||
IdleConnTimeout: transportPool.IdleConnTimeout,
|
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &httputil.ReverseProxy{
|
|
||||||
Director: director,
|
|
||||||
Transport: transport,
|
|
||||||
ModifyResponse: func(resp *http.Response) error {
|
|
||||||
// Preserve all response headers, including Set-Cookie
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
resp.Header[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle WebSocket upgrade
|
|
||||||
if strings.ToLower(resp.Header.Get("Upgrade")) == "websocket" {
|
|
||||||
return nil // No further modification needed for WebSocket
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compression if client supports it and response isn’t already compressed
|
|
||||||
if resp.Header.Get("Content-Encoding") == "" && strings.Contains(resp.Request.Header.Get("Accept-Encoding"), "gzip") {
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error reading response body for compression: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
gw := gzip.NewWriter(&buf)
|
|
||||||
if _, err := gw.Write(body); err != nil {
|
|
||||||
errorLogger.Printf("Error compressing response: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
gw.Close()
|
|
||||||
|
|
||||||
resp.Body = io.NopCloser(&buf)
|
|
||||||
resp.Header.Set("Content-Encoding", "gzip")
|
|
||||||
resp.Header.Del("Content-Length") // Length changes after compression
|
|
||||||
trafficLogger.Printf("Compressed response for %s", resp.Request.URL.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache static content with ETag support
|
|
||||||
if shouldCache(resp) {
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error reading response body for caching: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
|
||||||
|
|
||||||
etag := resp.Header.Get("ETag")
|
|
||||||
if etag == "" {
|
|
||||||
etag = generateETag(body) // Simple ETag generation if not provided
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheDuration := parseCacheControl(resp.Header.Get("Cache-Control"))
|
|
||||||
if cacheDuration == 0 {
|
|
||||||
cacheDuration = defaultCacheTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheMutex.Lock()
|
|
||||||
cache[resp.Request.URL.String()] = cachedResponse{
|
|
||||||
body: body,
|
|
||||||
headers: resp.Header.Clone(),
|
|
||||||
statusCode: resp.StatusCode,
|
|
||||||
cachedAt: time.Now(),
|
|
||||||
cacheDuration: cacheDuration,
|
|
||||||
etag: etag,
|
|
||||||
}
|
|
||||||
cacheMutex.Unlock()
|
|
||||||
trafficLogger.Printf("Cached response for %s [ETag: %s]", resp.Request.URL.String(), etag)
|
|
||||||
}
|
|
||||||
|
|
||||||
trafficLogger.Printf("Response: %s %d from %s", resp.Status, resp.StatusCode, target)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
log.Printf("Proxy error for %s: %v", r.Host, err)
|
|
||||||
errorLogger.Printf("Proxy error for %s: %v", r.Host, err)
|
|
||||||
http.Error(w, "Proxy error", http.StatusBadGateway)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldCache(resp *http.Response) bool {
|
|
||||||
if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
|
||||||
return strings.HasPrefix(contentType, "text/") || strings.HasPrefix(contentType, "image/") ||
|
|
||||||
strings.HasPrefix(contentType, "application/javascript") || strings.HasPrefix(contentType, "application/json")
|
|
||||||
}
|
|
||||||
|
|
||||||
func singleJoin(prefix, suffix string) string {
|
|
||||||
prefix = strings.TrimSuffix(prefix, "/")
|
|
||||||
suffix = strings.TrimPrefix(suffix, "/")
|
|
||||||
return prefix + "/" + suffix
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateETag(body []byte) string {
|
|
||||||
return fmt.Sprintf(`"%x"`, md5.Sum(body)) // Simple ETag based on MD5 hash
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseCacheControl(header string) time.Duration {
|
|
||||||
if header == "" {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
parts := strings.Split(header, ",")
|
|
||||||
for _, part := range parts {
|
|
||||||
if strings.Contains(part, "max-age=") {
|
|
||||||
ageStr := strings.TrimPrefix(part, "max-age=")
|
|
||||||
ageStr = strings.TrimSpace(ageStr)
|
|
||||||
if age, err := strconv.Atoi(ageStr); err == nil {
|
|
||||||
return time.Duration(age) * time.Second
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLimiter(ip string) *rate.Limiter {
|
|
||||||
rateMutex.Lock()
|
|
||||||
defer rateMutex.Unlock()
|
|
||||||
|
|
||||||
if limiter, exists := rateLimiters[ip]; exists {
|
|
||||||
return limiter
|
|
||||||
}
|
|
||||||
limiter := rate.NewLimiter(rateLimit, rateBurst)
|
|
||||||
rateLimiters[ip] = limiter
|
|
||||||
return limiter
|
|
||||||
}
|
|
||||||
|
|
||||||
func handler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
configMux.RLock()
|
|
||||||
defer configMux.RUnlock()
|
|
||||||
|
|
||||||
// Rate limiting based on client IP
|
|
||||||
clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")] // Strip port
|
|
||||||
limiter := getLimiter(clientIP)
|
|
||||||
if !limiter.Allow() {
|
|
||||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
|
||||||
trafficLogger.Printf("Rate limited client %s", clientIP)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
target, exists := config.Routes[r.Host]
|
|
||||||
skipVerify := config.TrustTarget[r.Host]
|
|
||||||
noHTTPSRedirect := config.NoHTTPSRedirect[r.Host]
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
if target, exists = config.Routes["*"]; !exists {
|
|
||||||
http.Error(w, "Host not configured", http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
skipVerify = config.TrustTarget["*"]
|
|
||||||
noHTTPSRedirect = config.NoHTTPSRedirect["*"]
|
|
||||||
}
|
|
||||||
|
|
||||||
isHTTPS := target[:5] == "https"
|
|
||||||
isHTTPReq := r.TLS == nil
|
|
||||||
|
|
||||||
if isHTTPReq && isHTTPS && !noHTTPSRedirect {
|
|
||||||
redirectURL := "https://" + r.Host + r.RequestURI
|
|
||||||
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheKey := r.URL.String()
|
|
||||||
cacheMutex.RLock()
|
|
||||||
if cached, ok := cache[cacheKey]; ok && time.Since(cached.cachedAt) < cached.cacheDuration {
|
|
||||||
if etag := r.Header.Get("If-None-Match"); etag != "" && etag == cached.etag {
|
|
||||||
w.WriteHeader(http.StatusNotModified)
|
|
||||||
trafficLogger.Printf("Served 304 Not Modified from cache: %s [ETag: %s]", cacheKey, cached.etag)
|
|
||||||
} else {
|
|
||||||
for k, v := range cached.headers {
|
|
||||||
w.Header()[k] = v
|
|
||||||
}
|
|
||||||
w.WriteHeader(cached.statusCode)
|
|
||||||
w.Write(cached.body)
|
|
||||||
trafficLogger.Printf("Served from cache: %s [ETag: %s]", cacheKey, cached.etag)
|
|
||||||
}
|
|
||||||
cacheMutex.RUnlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cacheMutex.RUnlock()
|
|
||||||
|
|
||||||
proxy := getReverseProxy(target, skipVerify)
|
|
||||||
if proxy == nil {
|
|
||||||
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetURL, err := url.Parse(target)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error parsing target URL %s: %v", target, err)
|
|
||||||
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// WebSocket upgrade handling
|
|
||||||
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
|
|
||||||
hijacker, ok := w.(http.Hijacker)
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "WebSocket upgrade not supported", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
conn, _, err := hijacker.Hijack()
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Failed to hijack connection for WebSocket: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
targetConn, err := transportPool.Dial("tcp", targetURL.Host)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Failed to dial target for WebSocket: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer targetConn.Close()
|
|
||||||
|
|
||||||
// Forward request to target
|
|
||||||
if err := r.Write(targetConn); err != nil {
|
|
||||||
errorLogger.Printf("Failed to forward WebSocket request: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pipe connections
|
|
||||||
go io.Copy(conn, targetConn)
|
|
||||||
io.Copy(targetConn, conn)
|
|
||||||
trafficLogger.Printf("WebSocket connection established: %s -> %s", r.Host, target)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import "path/filepath"
|
|
||||||
|
|
||||||
func updatePaths() {
|
|
||||||
certDir = filepath.Join(baseDir, config.CertDir)
|
|
||||||
certPath = filepath.Join(certDir, config.CertFile)
|
|
||||||
keyPath = filepath.Join(certDir, config.KeyFile)
|
|
||||||
configPath = filepath.Join(baseDir, "config.yaml")
|
|
||||||
}
|
|
||||||
137
certificate.go
137
certificate.go
@@ -1,137 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
|
||||||
"math/big"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// generateSelfSignedCert creates a self-signed TLS certificate if one doesn’t exist
|
|
||||||
func generateSelfSignedCert() error {
|
|
||||||
// Ensure certificate directory exists
|
|
||||||
if err := os.MkdirAll(certDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to create certificate directory %s: %v", certDir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate a 2048-bit RSA private key
|
|
||||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to generate private key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set certificate validity period (1 year)
|
|
||||||
notBefore := time.Now()
|
|
||||||
notAfter := notBefore.Add(365 * 24 * time.Hour)
|
|
||||||
|
|
||||||
// Generate a random serial number
|
|
||||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to generate serial number: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Define certificate template with basic attributes
|
|
||||||
template := x509.Certificate{
|
|
||||||
SerialNumber: serialNumber,
|
|
||||||
Subject: pkix.Name{
|
|
||||||
Organization: []string{"Proxy Self-Signed"},
|
|
||||||
CommonName: "localhost",
|
|
||||||
},
|
|
||||||
NotBefore: notBefore,
|
|
||||||
NotAfter: notAfter,
|
|
||||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
DNSNames: []string{"localhost", "*.example.com"}, // Supported domains
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create self-signed certificate
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create certificate: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write certificate to file
|
|
||||||
certOut, err := os.Create(certPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to open %s for writing: %v", certPath, err)
|
|
||||||
}
|
|
||||||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
|
||||||
certOut.Close()
|
|
||||||
return fmt.Errorf("failed to encode certificate: %v", err)
|
|
||||||
}
|
|
||||||
certOut.Close()
|
|
||||||
|
|
||||||
// Write private key to file
|
|
||||||
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to open %s for writing: %v", keyPath, err)
|
|
||||||
}
|
|
||||||
if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
|
|
||||||
keyOut.Close()
|
|
||||||
return fmt.Errorf("failed to encode private key: %v", err)
|
|
||||||
}
|
|
||||||
keyOut.Close()
|
|
||||||
|
|
||||||
refreshLogger.Printf("Generated self-signed certificate in %s", certDir)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadCertificate reads and parses the certificate and key files into memory
|
|
||||||
func loadCertificate() error {
|
|
||||||
certFile, err := os.ReadFile(certPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read certificate %s: %v", certPath, err)
|
|
||||||
}
|
|
||||||
keyFile, err := os.ReadFile(keyPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read key %s: %v", keyPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
newCert, err := tls.X509KeyPair(certFile, keyFile)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse certificate: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
configMux.Lock()
|
|
||||||
cert = &newCert
|
|
||||||
configMux.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// monitorCertificates watches for changes to certificate/key files and reloads them
|
|
||||||
func monitorCertificates() {
|
|
||||||
var lastModTime time.Time
|
|
||||||
for {
|
|
||||||
certInfo, err := os.Stat(certPath)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error checking certificate: %v", err)
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
keyInfo, err := os.Stat(keyPath)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error checking key: %v", err)
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reload certificate if either file has changed
|
|
||||||
if certInfo.ModTime() != lastModTime || keyInfo.ModTime() != lastModTime {
|
|
||||||
if err := loadCertificate(); err != nil {
|
|
||||||
errorLogger.Printf("Error reloading certificate: %v", err)
|
|
||||||
} else {
|
|
||||||
refreshLogger.Println("Certificate reloaded successfully")
|
|
||||||
lastModTime = certInfo.ModTime()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Second) // Poll every 5 seconds
|
|
||||||
}
|
|
||||||
}
|
|
||||||
160
config.go
160
config.go
@@ -1,160 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gopkg.in/yaml.v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Config defines the structure of the proxy configuration loaded from config.yaml
|
|
||||||
type Config struct {
|
|
||||||
ListenHTTP string `yaml:"listen_http"` // Port for HTTP server (e.g., ":80")
|
|
||||||
ListenHTTPS string `yaml:"listen_https"` // Port for HTTPS server (e.g., ":443")
|
|
||||||
CertDir string `yaml:"cert_dir"` // Directory for certificate files
|
|
||||||
CertFile string `yaml:"cert_file"` // Certificate filename
|
|
||||||
KeyFile string `yaml:"key_file"` // Private key filename
|
|
||||||
Routes map[string]string `yaml:"routes"` // Mapping of hostnames to target URLs
|
|
||||||
TrustTarget map[string]bool `yaml:"trust_target"` // Whether to skip TLS verification for targets
|
|
||||||
NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"` // Whether to skip HTTP->HTTPS redirect
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadConfig reads and parses the config.yaml file
|
|
||||||
func loadConfig() (Config, error) {
|
|
||||||
data, err := os.ReadFile(configPath)
|
|
||||||
if err != nil {
|
|
||||||
return Config{}, fmt.Errorf("failed to read config %s: %v", configPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cfg Config
|
|
||||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
|
||||||
return Config{}, fmt.Errorf("failed to unmarshal config: %v", err)
|
|
||||||
}
|
|
||||||
return cfg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateDefaultConfig creates a default configuration if config.yaml doesn’t exist
|
|
||||||
func generateDefaultConfig() Config {
|
|
||||||
return Config{
|
|
||||||
ListenHTTP: ":80",
|
|
||||||
ListenHTTPS: ":443",
|
|
||||||
CertDir: "./certificate",
|
|
||||||
CertFile: "certificate.pem",
|
|
||||||
KeyFile: "key.pem",
|
|
||||||
Routes: map[string]string{
|
|
||||||
"*": "https://127.0.0.1:3000", // Wildcard route
|
|
||||||
"main.example.com": "https://10.100.111.254:4444", // Specific route
|
|
||||||
},
|
|
||||||
TrustTarget: map[string]bool{
|
|
||||||
"*": true, // Skip TLS verification by default
|
|
||||||
"main.example.com": true,
|
|
||||||
},
|
|
||||||
NoHTTPSRedirect: map[string]bool{
|
|
||||||
"*": false, // Redirect HTTP to HTTPS by default
|
|
||||||
"main.example.com": false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// saveConfig writes the configuration to config.yaml
|
|
||||||
func saveConfig(cfg Config) error {
|
|
||||||
if err := os.MkdirAll(baseDir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("failed to create base directory %s: %v", baseDir, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := yaml.Marshal(cfg)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(configPath, data, 0644); err != nil {
|
|
||||||
return fmt.Errorf("failed to write config to %s: %v", configPath, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// monitorConfig watches config.yaml for changes and updates the in-memory config
|
|
||||||
func monitorConfig() {
|
|
||||||
var lastModTime time.Time
|
|
||||||
for {
|
|
||||||
configInfo, err := os.Stat(configPath)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error checking config file: %v", err)
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if configInfo.ModTime() != lastModTime {
|
|
||||||
newConfig, err := loadConfig()
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error reloading config: %v", err)
|
|
||||||
} else {
|
|
||||||
configMux.Lock()
|
|
||||||
// Update individual fields only if they’ve changed
|
|
||||||
if newConfig.ListenHTTP != config.ListenHTTP {
|
|
||||||
config.ListenHTTP = newConfig.ListenHTTP
|
|
||||||
refreshLogger.Printf("Updated listen_http to %s", config.ListenHTTP)
|
|
||||||
}
|
|
||||||
if newConfig.ListenHTTPS != config.ListenHTTPS {
|
|
||||||
config.ListenHTTPS = newConfig.ListenHTTPS
|
|
||||||
refreshLogger.Printf("Updated listen_https to %s", config.ListenHTTPS)
|
|
||||||
}
|
|
||||||
if newConfig.CertDir != config.CertDir || newConfig.CertFile != config.CertFile || newConfig.KeyFile != config.KeyFile {
|
|
||||||
config.CertDir = newConfig.CertDir
|
|
||||||
config.CertFile = newConfig.CertFile
|
|
||||||
config.KeyFile = newConfig.KeyFile
|
|
||||||
updatePaths()
|
|
||||||
if err := loadCertificate(); err != nil {
|
|
||||||
errorLogger.Printf("Error reloading certificate after path change: %v", err)
|
|
||||||
} else {
|
|
||||||
refreshLogger.Println("Updated certificate paths and reloaded certificate")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Update routes
|
|
||||||
for k, v := range newConfig.Routes {
|
|
||||||
if oldV, exists := config.Routes[k]; !exists || oldV != v {
|
|
||||||
config.Routes[k] = v
|
|
||||||
refreshLogger.Printf("Updated route %s to %s", k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k := range config.Routes {
|
|
||||||
if _, exists := newConfig.Routes[k]; !exists {
|
|
||||||
delete(config.Routes, k)
|
|
||||||
refreshLogger.Printf("Removed route %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Update trust_target
|
|
||||||
for k, v := range newConfig.TrustTarget {
|
|
||||||
if oldV, exists := config.TrustTarget[k]; !exists || oldV != v {
|
|
||||||
config.TrustTarget[k] = v
|
|
||||||
refreshLogger.Printf("Updated trust_target %s to %v", k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k := range config.TrustTarget {
|
|
||||||
if _, exists := newConfig.TrustTarget[k]; !exists {
|
|
||||||
delete(config.TrustTarget, k)
|
|
||||||
refreshLogger.Printf("Removed trust_target %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Update no_https_redirect
|
|
||||||
for k, v := range newConfig.NoHTTPSRedirect {
|
|
||||||
if oldV, exists := config.NoHTTPSRedirect[k]; !exists || oldV != v {
|
|
||||||
config.NoHTTPSRedirect[k] = v
|
|
||||||
refreshLogger.Printf("Updated no_https_redirect %s to %v", k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k := range config.NoHTTPSRedirect {
|
|
||||||
if _, exists := newConfig.NoHTTPSRedirect[k]; !exists {
|
|
||||||
delete(config.NoHTTPSRedirect, k)
|
|
||||||
refreshLogger.Printf("Removed no_https_redirect %s", k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
configMux.Unlock()
|
|
||||||
refreshLogger.Println("Config reloaded successfully")
|
|
||||||
lastModTime = configInfo.ModTime()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Second) // Poll every 5 seconds
|
|
||||||
}
|
|
||||||
}
|
|
||||||
11
go.mod
11
go.mod
@@ -1,11 +0,0 @@
|
|||||||
module main
|
|
||||||
|
|
||||||
go 1.24.0
|
|
||||||
|
|
||||||
require gopkg.in/yaml.v2 v2.4.0
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
|
||||||
golang.org/x/sys v0.13.0 // indirect
|
|
||||||
golang.org/x/time v0.10.0 // indirect
|
|
||||||
)
|
|
||||||
10
go.sum
10
go.sum
@@ -1,10 +0,0 @@
|
|||||||
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
|
||||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
|
||||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
|
||||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
|
||||||
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
||||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
|
||||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
|
||||||
BIN
golangproxy/build/icon.ico
Normal file
BIN
golangproxy/build/icon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 264 KiB |
BIN
golangproxy/build/icon.png
Normal file
BIN
golangproxy/build/icon.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
16
golangproxy/config.yaml
Normal file
16
golangproxy/config.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
listen_http: :80
|
||||||
|
listen_https: :443
|
||||||
|
cert_file: ./crt/certificate.pem
|
||||||
|
key_file: ./crt/key.pem
|
||||||
|
routes:
|
||||||
|
'*': http://127.0.0.1:61147
|
||||||
|
gg.example.com: https://example.com:443
|
||||||
|
main.example.com: https://10.100.111.254:4444
|
||||||
|
trust_target:
|
||||||
|
'*': true
|
||||||
|
gg.example.com: false
|
||||||
|
main.example.com: true
|
||||||
|
no_https_redirect:
|
||||||
|
'*': false
|
||||||
|
gg.example.com: true
|
||||||
|
main.example.com: false
|
||||||
65
golangproxy/config/config.go
Normal file
65
golangproxy/config/config.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config represents the application configuration
|
||||||
|
type Config struct {
|
||||||
|
ListenHTTP string `yaml:"listen_http"` // HTTP listen address (e.g., ":80")
|
||||||
|
ListenHTTPS string `yaml:"listen_https"` // HTTPS listen address (e.g., ":443")
|
||||||
|
CertFile string `yaml:"cert_file"` // Path to SSL certificate
|
||||||
|
KeyFile string `yaml:"key_file"` // Path to SSL key
|
||||||
|
Routes map[string]string `yaml:"routes"` // Host to target URL mappings
|
||||||
|
TrustTarget map[string]bool `yaml:"trust_target"` // Whether to trust invalid target certs
|
||||||
|
NoHTTPSRedirect map[string]bool `yaml:"no_https_redirect"` // Disable HTTP to HTTPS redirect
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfig loads the config from file or creates a default one
|
||||||
|
func LoadConfig(configPath string) (*Config, error) {
|
||||||
|
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||||
|
// Create default configuration
|
||||||
|
defaultConfig := &Config{
|
||||||
|
ListenHTTP: ":80",
|
||||||
|
ListenHTTPS: ":443",
|
||||||
|
CertFile: "./crt/certificate.pem",
|
||||||
|
KeyFile: "./crt/key.pem",
|
||||||
|
Routes: map[string]string{
|
||||||
|
"*": "http://127.0.0.1:61147", // accespt all route
|
||||||
|
"main.example.com": "https://10.100.111.254:4444", // Specific route
|
||||||
|
"gg.example.com": "https://example.com:443",
|
||||||
|
},
|
||||||
|
TrustTarget: map[string]bool{
|
||||||
|
"*": true, // true = trust any certificates on target url
|
||||||
|
"main.example.com": true,
|
||||||
|
"gg.example.com": false, // trusting target cetificate disabled
|
||||||
|
},
|
||||||
|
NoHTTPSRedirect: map[string]bool{
|
||||||
|
"*": false, // false = HTTP redirected to HTTPS automatically
|
||||||
|
"main.example.com": false,
|
||||||
|
"gg.example.com": true, // no automatic redirect to HTTPS from HTTP
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := yaml.Marshal(defaultConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(configPath, data, 0644); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return defaultConfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var config Config
|
||||||
|
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
24
golangproxy/go.mod
Normal file
24
golangproxy/go.mod
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
module golangproxy
|
||||||
|
|
||||||
|
go 1.24.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/fsnotify/fsnotify v1.8.0
|
||||||
|
gopkg.in/yaml.v2 v2.4.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/akavel/rsrc v0.10.2 // indirect
|
||||||
|
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 // indirect
|
||||||
|
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 // indirect
|
||||||
|
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 // indirect
|
||||||
|
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 // indirect
|
||||||
|
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 // indirect
|
||||||
|
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f // indirect
|
||||||
|
github.com/getlantern/systray v1.2.2 // indirect
|
||||||
|
github.com/go-stack/stack v1.8.0 // indirect
|
||||||
|
github.com/josephspurrier/goversioninfo v1.4.1 // indirect
|
||||||
|
github.com/kardianos/service v1.2.2 // indirect
|
||||||
|
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect
|
||||||
|
golang.org/x/sys v0.13.0 // indirect
|
||||||
|
)
|
||||||
45
golangproxy/go.sum
Normal file
45
golangproxy/go.sum
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
github.com/akavel/rsrc v0.10.2 h1:Zxm8V5eI1hW4gGaYsJQUhxpjkENuG91ki8B4zCrvEsw=
|
||||||
|
github.com/akavel/rsrc v0.10.2/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
||||||
|
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
|
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520 h1:NRUJuo3v3WGC/g5YiyF790gut6oQr5f3FBI88Wv0dx4=
|
||||||
|
github.com/getlantern/context v0.0.0-20190109183933-c447772a6520/go.mod h1:L+mq6/vvYHKjCX2oez0CgEAJmbq1fbb/oNJIWQkBybY=
|
||||||
|
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7 h1:6uJ+sZ/e03gkbqZ0kUG6mfKoqDb4XMAzMIwlajq19So=
|
||||||
|
github.com/getlantern/errors v0.0.0-20190325191628-abdb3e3e36f7/go.mod h1:l+xpFBrCtDLpK9qNjxs+cHU6+BAdlBaxHqikB6Lku3A=
|
||||||
|
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7 h1:guBYzEaLz0Vfc/jv0czrr2z7qyzTOGC9hiQ0VC+hKjk=
|
||||||
|
github.com/getlantern/golog v0.0.0-20190830074920-4ef2e798c2d7/go.mod h1:zx/1xUUeYPy3Pcmet8OSXLbF47l+3y6hIPpyLWoR9oc=
|
||||||
|
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7 h1:micT5vkcr9tOVk1FiH8SWKID8ultN44Z+yzd2y/Vyb0=
|
||||||
|
github.com/getlantern/hex v0.0.0-20190417191902-c6586a6fe0b7/go.mod h1:dD3CgOrwlzca8ed61CsZouQS5h5jIzkK9ZWrTcf0s+o=
|
||||||
|
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 h1:XYzSdCbkzOC0FDNrgJqGRo8PCMFOBFL9py72DRs7bmc=
|
||||||
|
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55/go.mod h1:6mmzY2kW1TOOrVy+r41Za2MxXM+hhqTtY3oBKd2AgFA=
|
||||||
|
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f h1:wrYrQttPS8FHIRSlsrcuKazukx/xqO/PpLZzZXsF+EA=
|
||||||
|
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f/go.mod h1:D5ao98qkA6pxftxoqzibIBBrLSUli+kYnJqrgBf9cIA=
|
||||||
|
github.com/getlantern/systray v1.2.2 h1:dCEHtfmvkJG7HZ8lS/sLklTH4RKUcIsKrAD9sThoEBE=
|
||||||
|
github.com/getlantern/systray v1.2.2/go.mod h1:pXFOI1wwqwYXEhLPm9ZGjS2u/vVELeIgNMY5HvhHhcE=
|
||||||
|
github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk=
|
||||||
|
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||||
|
github.com/josephspurrier/goversioninfo v1.4.1 h1:5LvrkP+n0tg91J9yTkoVnt/QgNnrI1t4uSsWjIonrqY=
|
||||||
|
github.com/josephspurrier/goversioninfo v1.4.1/go.mod h1:JWzv5rKQr+MmW+LvM412ToT/IkYDZjaclF2pKDss8IY=
|
||||||
|
github.com/kardianos/service v1.2.2 h1:ZvePhAHfvo0A7Mftk/tEzqEZ7Q4lgnR8sGz4xu1YX60=
|
||||||
|
github.com/kardianos/service v1.2.2/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
|
github.com/lxn/walk v0.0.0-20210112085537-c389da54e794/go.mod h1:E23UucZGqpuUANJooIbHWCufXvOcT6E7Stq81gU+CSQ=
|
||||||
|
github.com/lxn/win v0.0.0-20210218163916-a377121e959e/go.mod h1:KxxjdtRkfNoYDCUP5ryK7XJJNTnpC8atvtmTheChOtk=
|
||||||
|
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c h1:rp5dCmg/yLR3mgFuSOe4oEnDDmGLROTvMragMUXpTQw=
|
||||||
|
github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c/go.mod h1:X07ZCGwUbLaax7L0S3Tw4hpejzu63ZrrQiUe6W0hcy0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||||
|
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
gopkg.in/Knetic/govaluate.v3 v3.0.0/go.mod h1:csKLBORsPbafmSCGTEh3U7Ozmsuq8ZSIlKk1bcqph0E=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||||
|
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
23
golangproxy/info.md
Normal file
23
golangproxy/info.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# structure:
|
||||||
|
```
|
||||||
|
/golangproxy
|
||||||
|
├── main.go # Application entry point
|
||||||
|
├── config/
|
||||||
|
│ └── config.go # Configuration loading and parsing
|
||||||
|
├── proxy/
|
||||||
|
│ └── proxy.go # Reverse proxy logic
|
||||||
|
├── server/
|
||||||
|
│ └── server.go # Simple web server implementation
|
||||||
|
├── ssl/
|
||||||
|
│ └── ssl.go # SSL certificate management
|
||||||
|
├── logger/
|
||||||
|
│ └── logger.go # Logging setup
|
||||||
|
├── logs/ # Logs directory (created at runtime)
|
||||||
|
├── ssl/ # SSL certificates directory (created at runtime)
|
||||||
|
├── www/ # Web server content directory (created at runtime)
|
||||||
|
└── tests/ # Test files
|
||||||
|
├── config_test.go # Tests for config package
|
||||||
|
├── proxy_test.go # Tests for proxy package
|
||||||
|
├── server_test.go # Tests for server package
|
||||||
|
└── ssl_test.go # Tests for ssl package
|
||||||
|
```
|
||||||
40
golangproxy/logger/logger.go
Normal file
40
golangproxy/logger/logger.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Logger is the global logger instance
|
||||||
|
var Logger *log.Logger
|
||||||
|
|
||||||
|
// InitLogger initializes logging to file and stdout
|
||||||
|
func InitLogger() {
|
||||||
|
if err := os.MkdirAll("logs", 0755); err != nil {
|
||||||
|
log.Fatalf("Error creating logs directory: %v", err)
|
||||||
|
}
|
||||||
|
logFile, err := os.OpenFile(filepath.Join("logs", "proxy.log"), os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error opening log file: %v", err)
|
||||||
|
}
|
||||||
|
multiWriter := io.MultiWriter(os.Stdout, logFile)
|
||||||
|
Logger = log.New(multiWriter, "", log.LstdFlags)
|
||||||
|
// Wrap the logger to filter context canceled errors
|
||||||
|
oldOutput := Logger.Writer()
|
||||||
|
Logger.SetOutput(&filteredWriter{Writer: oldOutput})
|
||||||
|
}
|
||||||
|
|
||||||
|
// filteredWriter wraps an io.Writer to filter out context canceled errors
|
||||||
|
type filteredWriter struct {
|
||||||
|
Writer io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fw *filteredWriter) Write(p []byte) (n int, err error) {
|
||||||
|
if strings.Contains(string(p), "context canceled") && strings.Contains(string(p), "http: proxy error") {
|
||||||
|
return len(p), nil // Silently discard the message
|
||||||
|
}
|
||||||
|
return fw.Writer.Write(p)
|
||||||
|
}
|
||||||
345
golangproxy/main.go
Normal file
345
golangproxy/main.go
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
|
|
||||||
|
"golangproxy/config"
|
||||||
|
"golangproxy/logger"
|
||||||
|
"golangproxy/proxy"
|
||||||
|
"golangproxy/server"
|
||||||
|
"golangproxy/ssl"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Global variables for dynamic configuration and certificate updates
|
||||||
|
var (
|
||||||
|
configPath = "config.yaml"
|
||||||
|
routesMutex sync.RWMutex // Protects routes and defaultRoute
|
||||||
|
certMutex sync.RWMutex // Protects currentCert
|
||||||
|
currentConfig *config.Config // Current configuration
|
||||||
|
currentCert *tls.Certificate // Current SSL certificate
|
||||||
|
routes map[string]*proxy.Route // Host-specific routes
|
||||||
|
defaultRoute *proxy.Route // Wildcard route
|
||||||
|
watcher *fsnotify.Watcher // File watcher instance
|
||||||
|
)
|
||||||
|
|
||||||
|
// main initializes and runs the reverse proxy application
|
||||||
|
func main() {
|
||||||
|
// Initialize logging to file and terminal
|
||||||
|
logger.InitLogger()
|
||||||
|
log := logger.Logger
|
||||||
|
|
||||||
|
// Load initial configuration
|
||||||
|
var err error
|
||||||
|
currentConfig, err = config.LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error loading config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure SSL certificate and key files exist
|
||||||
|
err = ssl.EnsureCertFiles(currentConfig.CertFile, currentConfig.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error ensuring cert files: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load initial SSL certificate
|
||||||
|
cert, err := tls.LoadX509KeyPair(currentConfig.CertFile, currentConfig.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error loading cert: %v", err)
|
||||||
|
}
|
||||||
|
certMutex.Lock()
|
||||||
|
currentCert = &cert
|
||||||
|
certMutex.Unlock()
|
||||||
|
|
||||||
|
// Initialize proxy routes from config
|
||||||
|
initializeRoutes(log)
|
||||||
|
|
||||||
|
// Start the simple web server in a goroutine
|
||||||
|
go server.StartServer()
|
||||||
|
|
||||||
|
// Configure HTTP server
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: currentConfig.ListenHTTP,
|
||||||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
routesMutex.RLock()
|
||||||
|
route := getRoute(r.Host)
|
||||||
|
routesMutex.RUnlock()
|
||||||
|
if strings.HasPrefix(route.Target, "https://") && !route.NoHTTPSRedirect {
|
||||||
|
httpsURL := "https://" + r.Host + r.URL.Path
|
||||||
|
if r.URL.RawQuery != "" {
|
||||||
|
httpsURL += "?" + r.URL.RawQuery
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, httpsURL, http.StatusMovedPermanently)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
route.Handler.ServeHTTP(w, r) // Use Handler instead of Proxy
|
||||||
|
}),
|
||||||
|
ErrorLog: logger.Logger, // Add this to filter server-level errors (from previous fix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure HTTPS server
|
||||||
|
httpsServer := &http.Server{
|
||||||
|
Addr: currentConfig.ListenHTTPS,
|
||||||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
routesMutex.RLock()
|
||||||
|
route := getRoute(r.Host)
|
||||||
|
routesMutex.RUnlock()
|
||||||
|
route.Handler.ServeHTTP(w, r) // Use Handler instead of Proxy
|
||||||
|
}),
|
||||||
|
TLSConfig: &tls.Config{
|
||||||
|
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
certMutex.RLock()
|
||||||
|
defer certMutex.RUnlock()
|
||||||
|
return currentCert, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ErrorLog: logger.Logger, // Add this to filter server-level errors (from previous fix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start servers in goroutines
|
||||||
|
go func() {
|
||||||
|
log.Println("Starting HTTP server on", currentConfig.ListenHTTP)
|
||||||
|
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatalf("HTTP server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Println("Starting HTTPS server on", currentConfig.ListenHTTPS)
|
||||||
|
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatalf("HTTPS server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Initialize file watcher
|
||||||
|
watcher, err = fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error creating watcher: %v", err)
|
||||||
|
}
|
||||||
|
defer watcher.Close()
|
||||||
|
|
||||||
|
// Watch initial config and cert files
|
||||||
|
err = watcher.Add(configPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error watching config file: %v", err)
|
||||||
|
}
|
||||||
|
err = watcher.Add(currentConfig.CertFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Error watching cert file:", err)
|
||||||
|
}
|
||||||
|
err = watcher.Add(currentConfig.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Error watching key file:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle file updates in a goroutine
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event, ok := <-watcher.Events:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||||
|
switch event.Name {
|
||||||
|
case configPath:
|
||||||
|
log.Println("Config file changed, reloading...")
|
||||||
|
reloadConfig(log)
|
||||||
|
case currentConfig.CertFile, currentConfig.KeyFile:
|
||||||
|
log.Println("Cert files changed, reloading cert...")
|
||||||
|
reloadCert(log)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case err, ok := <-watcher.Errors:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println("Watcher error:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Handle graceful shutdown
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-quit
|
||||||
|
log.Println("Shutting down...")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := httpServer.Shutdown(ctx); err != nil {
|
||||||
|
log.Println("HTTP server shutdown error:", err)
|
||||||
|
}
|
||||||
|
if err := httpsServer.Shutdown(ctx); err != nil {
|
||||||
|
log.Println("HTTPS server shutdown error:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRoute retrieves the appropriate proxy route for a host
|
||||||
|
func getRoute(host string) *proxy.Route {
|
||||||
|
routesMutex.RLock()
|
||||||
|
defer routesMutex.RUnlock()
|
||||||
|
if route, ok := routes[host]; ok {
|
||||||
|
return route
|
||||||
|
}
|
||||||
|
return defaultRoute
|
||||||
|
}
|
||||||
|
|
||||||
|
// initializeRoutes sets up the routes map and default route from the current config
|
||||||
|
func initializeRoutes(log *log.Logger) {
|
||||||
|
routesMutex.Lock()
|
||||||
|
defer routesMutex.Unlock()
|
||||||
|
|
||||||
|
routes = make(map[string]*proxy.Route)
|
||||||
|
for host, target := range currentConfig.Routes {
|
||||||
|
if host == "*" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
trust := getConfigBool(currentConfig.TrustTarget, host)
|
||||||
|
noRedirect := getConfigBool(currentConfig.NoHTTPSRedirect, host)
|
||||||
|
route := proxy.CreateRoute(target, trust)
|
||||||
|
route.NoHTTPSRedirect = noRedirect
|
||||||
|
routes[host] = route
|
||||||
|
}
|
||||||
|
defaultTarget, ok := currentConfig.Routes["*"]
|
||||||
|
if !ok {
|
||||||
|
log.Fatal("Default route '*' not found in config")
|
||||||
|
}
|
||||||
|
defaultTrust := currentConfig.TrustTarget["*"]
|
||||||
|
defaultNoRedirect := currentConfig.NoHTTPSRedirect["*"]
|
||||||
|
defaultRoute = proxy.CreateRoute(defaultTarget, defaultTrust)
|
||||||
|
defaultRoute.NoHTTPSRedirect = defaultNoRedirect
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConfigBool retrieves a boolean config value, falling back to '*' if host-specific value is absent
|
||||||
|
func getConfigBool(m map[string]bool, host string) bool {
|
||||||
|
if val, ok := m[host]; ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return m["*"]
|
||||||
|
}
|
||||||
|
|
||||||
|
// reloadConfig reloads the configuration and updates routes and certs if necessary
|
||||||
|
func reloadConfig(log *log.Logger) {
|
||||||
|
newConfig, err := config.LoadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Error reloading config:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log differences between old and new config
|
||||||
|
log.Println("Config file changed, reloading...")
|
||||||
|
logConfigChanges(log, currentConfig, newConfig)
|
||||||
|
|
||||||
|
// Store old cert file paths before updating config
|
||||||
|
oldCertFile := currentConfig.CertFile
|
||||||
|
oldKeyFile := currentConfig.KeyFile
|
||||||
|
certChanged := newConfig.CertFile != oldCertFile || newConfig.KeyFile != oldKeyFile
|
||||||
|
|
||||||
|
currentConfig = newConfig
|
||||||
|
|
||||||
|
// Update routes
|
||||||
|
initializeRoutes(log)
|
||||||
|
|
||||||
|
// Update certificates and watcher if paths changed
|
||||||
|
if certChanged {
|
||||||
|
reloadCert(log)
|
||||||
|
updateCertWatchers(log, oldCertFile, oldKeyFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// logConfigChanges logs the differences between old and new config
|
||||||
|
func logConfigChanges(log *log.Logger, oldConfig, newConfig *config.Config) {
|
||||||
|
if oldConfig.ListenHTTP != newConfig.ListenHTTP {
|
||||||
|
log.Printf("listen_http changed from %s to %s", oldConfig.ListenHTTP, newConfig.ListenHTTP)
|
||||||
|
}
|
||||||
|
if oldConfig.ListenHTTPS != newConfig.ListenHTTPS {
|
||||||
|
log.Printf("listen_https changed from %s to %s", oldConfig.ListenHTTPS, newConfig.ListenHTTPS)
|
||||||
|
}
|
||||||
|
if oldConfig.CertFile != newConfig.CertFile {
|
||||||
|
log.Printf("cert_file changed from %s to %s", oldConfig.CertFile, newConfig.CertFile)
|
||||||
|
}
|
||||||
|
if oldConfig.KeyFile != newConfig.KeyFile {
|
||||||
|
log.Printf("key_file changed from %s to %s", oldConfig.KeyFile, newConfig.KeyFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare Routes
|
||||||
|
for key := range oldConfig.Routes {
|
||||||
|
if newVal, ok := newConfig.Routes[key]; !ok {
|
||||||
|
log.Printf("Route %s removed (was %s)", key, oldConfig.Routes[key])
|
||||||
|
} else if oldConfig.Routes[key] != newVal {
|
||||||
|
log.Printf("Route %s changed from %s to %s", key, oldConfig.Routes[key], newVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for key, newVal := range newConfig.Routes {
|
||||||
|
if _, ok := oldConfig.Routes[key]; !ok {
|
||||||
|
log.Printf("Route %s added: %s", key, newVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare TrustTarget
|
||||||
|
for key := range oldConfig.TrustTarget {
|
||||||
|
if newVal, ok := newConfig.TrustTarget[key]; !ok {
|
||||||
|
log.Printf("trust_target %s removed (was %t)", key, oldConfig.TrustTarget[key])
|
||||||
|
} else if oldConfig.TrustTarget[key] != newVal {
|
||||||
|
log.Printf("trust_target %s changed from %t to %t", key, oldConfig.TrustTarget[key], newVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for key, newVal := range newConfig.TrustTarget {
|
||||||
|
if _, ok := oldConfig.TrustTarget[key]; !ok {
|
||||||
|
log.Printf("trust_target %s added: %t", key, newVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare NoHTTPSRedirect
|
||||||
|
for key := range oldConfig.NoHTTPSRedirect {
|
||||||
|
if newVal, ok := newConfig.NoHTTPSRedirect[key]; !ok {
|
||||||
|
log.Printf("no_https_redirect %s removed (was %t)", key, oldConfig.NoHTTPSRedirect[key])
|
||||||
|
} else if oldConfig.NoHTTPSRedirect[key] != newVal {
|
||||||
|
log.Printf("no_https_redirect %s changed from %t to %t", key, oldConfig.NoHTTPSRedirect[key], newVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for key, newVal := range newConfig.NoHTTPSRedirect {
|
||||||
|
if _, ok := oldConfig.NoHTTPSRedirect[key]; !ok {
|
||||||
|
log.Printf("no_https_redirect %s added: %t", key, newVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reloadCert reloads the SSL certificate from disk
|
||||||
|
func reloadCert(log *log.Logger) {
|
||||||
|
cert, err := tls.LoadX509KeyPair(currentConfig.CertFile, currentConfig.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Error reloading cert:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
certMutex.Lock()
|
||||||
|
currentCert = &cert
|
||||||
|
certMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateCertWatchers updates the file watcher for new cert file paths
|
||||||
|
func updateCertWatchers(log *log.Logger, oldCertFile, oldKeyFile string) {
|
||||||
|
if oldCertFile != currentConfig.CertFile {
|
||||||
|
watcher.Remove(oldCertFile)
|
||||||
|
if err := watcher.Add(currentConfig.CertFile); err != nil {
|
||||||
|
log.Println("Error watching new cert file:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if oldKeyFile != currentConfig.KeyFile {
|
||||||
|
watcher.Remove(oldKeyFile)
|
||||||
|
if err := watcher.Add(currentConfig.KeyFile); err != nil {
|
||||||
|
log.Println("Error watching new key file:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
96
golangproxy/proxy/proxy.go
Normal file
96
golangproxy/proxy/proxy.go
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"golangproxy/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Route holds proxy configuration for a specific host
|
||||||
|
type Route struct {
|
||||||
|
Proxy *httputil.ReverseProxy // The reverse proxy instance
|
||||||
|
Handler http.Handler // Custom handler wrapping the proxy
|
||||||
|
NoHTTPSRedirect bool // Disable HTTP to HTTPS redirect
|
||||||
|
Target string // Target URL for proxying
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateRoute initializes a reverse proxy for a target with trust settings
|
||||||
|
func CreateRoute(target string, trustInvalidCert bool) *Route {
|
||||||
|
url, _ := url.Parse(target)
|
||||||
|
proxy := httputil.NewSingleHostReverseProxy(url)
|
||||||
|
if url.Scheme == "https" {
|
||||||
|
proxy.Transport = &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: trustInvalidCert},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify the Director based on whether the target is an IP or hostname
|
||||||
|
originalDirector := proxy.Director
|
||||||
|
proxy.Director = func(req *http.Request) {
|
||||||
|
originalDirector(req)
|
||||||
|
if isIPTarget(url.Hostname()) {
|
||||||
|
// For IP targets, preserve the incoming Host header (e.g., main.example.com)
|
||||||
|
// This ensures session cookies match the client's requested domain
|
||||||
|
} else {
|
||||||
|
// For hostname targets, set Host to the target's hostname (e.g., example.com)
|
||||||
|
req.Host = url.Host
|
||||||
|
}
|
||||||
|
req.Header.Set("X-Forwarded-For", req.RemoteAddr)
|
||||||
|
req.Header.Set("X-Forwarded-Host", req.Host)
|
||||||
|
req.Header.Set("X-Forwarded-Proto", url.Scheme)
|
||||||
|
if req.Header.Get("User-Agent") == "" {
|
||||||
|
req.Header.Set("User-Agent", "GoLangProxy")
|
||||||
|
}
|
||||||
|
//logger.Logger.Printf("Proxying to %s - Headers: %v, Cookies: %v", target, req.Header, req.Cookies())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a custom handler to wrap the proxy and filter context canceled errors
|
||||||
|
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
rwWrapper := &responseWriterWrapper{ResponseWriter: rw}
|
||||||
|
proxy.ServeHTTP(rwWrapper, req)
|
||||||
|
if err := req.Context().Err(); err != nil && err != context.Canceled {
|
||||||
|
logger.Logger.Printf("Proxy error for %s: %v", target, err)
|
||||||
|
}
|
||||||
|
//logger.Logger.Printf("Response from %s - Headers: %v, Status: %d", target, rwWrapper.Header(), rwWrapper.status)
|
||||||
|
})
|
||||||
|
|
||||||
|
return &Route{
|
||||||
|
Proxy: proxy,
|
||||||
|
Handler: handler,
|
||||||
|
Target: target,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIPTarget checks if the target hostname is an IP address
|
||||||
|
func isIPTarget(host string) bool {
|
||||||
|
// Split host and port if a port is present (e.g., "10.100.111.254:4444")
|
||||||
|
hostname, _, err := net.SplitHostPort(host)
|
||||||
|
if err != nil {
|
||||||
|
// If there's no port (or an error splitting), use the original host
|
||||||
|
hostname = host
|
||||||
|
}
|
||||||
|
return net.ParseIP(hostname) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseWriterWrapper captures response status and headers
|
||||||
|
type responseWriterWrapper struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
status int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *responseWriterWrapper) WriteHeader(status int) {
|
||||||
|
rw.status = status
|
||||||
|
rw.ResponseWriter.WriteHeader(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *responseWriterWrapper) Write(b []byte) (int, error) {
|
||||||
|
if rw.status == 0 {
|
||||||
|
rw.status = http.StatusOK
|
||||||
|
}
|
||||||
|
return rw.ResponseWriter.Write(b)
|
||||||
|
}
|
||||||
39
golangproxy/server/server.go
Normal file
39
golangproxy/server/server.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StartServer launches a web server on 127.0.0.1:61147
|
||||||
|
func StartServer() {
|
||||||
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
indexPath := filepath.Join("www", "index.html")
|
||||||
|
if _, err := os.Stat(indexPath); os.IsNotExist(err) {
|
||||||
|
// Create index.html if it doesn’t exist
|
||||||
|
if err := os.MkdirAll("www", 0755); err != nil {
|
||||||
|
http.Error(w, "Error creating www directory", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file, err := os.Create(indexPath)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Error creating index.html", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
_, err = file.WriteString("<h1>GoLangProxy is up</h1>")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Error writing index.html", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
http.ServeFile(w, r, indexPath)
|
||||||
|
})
|
||||||
|
|
||||||
|
fmt.Println("Starting simple web server on 127.0.0.1:61147")
|
||||||
|
if err := http.ListenAndServe("127.0.0.1:61147", nil); err != nil {
|
||||||
|
fmt.Println("Web server error:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
101
golangproxy/ssl/ssl.go
Normal file
101
golangproxy/ssl/ssl.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
package ssl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"math/big"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golangproxy/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EnsureCertFiles ensures SSL certificate and key files exist, generating self-signed if needed
|
||||||
|
func EnsureCertFiles(certPath, keyPath string) error {
|
||||||
|
_, certErr := os.Stat(certPath)
|
||||||
|
_, keyErr := os.Stat(keyPath)
|
||||||
|
if os.IsNotExist(certErr) || os.IsNotExist(keyErr) {
|
||||||
|
logger.Logger.Printf("Certificate or key missing, generating new ones: %s, %s", certPath, keyPath)
|
||||||
|
return generateSelfSignedCert(certPath, keyPath)
|
||||||
|
}
|
||||||
|
logger.Logger.Printf("Certificate and key found: %s, %s", certPath, keyPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateSelfSignedCert creates a self-signed certificate and key
|
||||||
|
func generateSelfSignedCert(certPath, keyPath string) error {
|
||||||
|
// Ensure ssl directory exists
|
||||||
|
if err := os.MkdirAll(filepath.Dir(certPath), 0755); err != nil {
|
||||||
|
logger.Logger.Printf("Error creating ssl directory: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Logger.Println("Created ssl directory")
|
||||||
|
|
||||||
|
// Generate private key
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
logger.Logger.Printf("Error generating private key: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := priv.Validate(); err != nil {
|
||||||
|
logger.Logger.Printf("Generated private key is invalid: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Logger.Println("Generated and validated 2048-bit RSA private key")
|
||||||
|
|
||||||
|
// Create certificate template
|
||||||
|
template := x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
Organization: []string{"GoLangProxy Self-Signed"},
|
||||||
|
CommonName: "example.com",
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
DNSNames: []string{"example.com", "localhost"}, // SANs required
|
||||||
|
}
|
||||||
|
logger.Logger.Printf("Created certificate template with CN=%s, DNSNames=%v", template.Subject.CommonName, template.DNSNames)
|
||||||
|
|
||||||
|
// Generate certificate
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||||
|
if err != nil {
|
||||||
|
logger.Logger.Printf("Error creating certificate: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Logger.Printf("Generated certificate, DER length: %d", len(certDER))
|
||||||
|
|
||||||
|
// Write certificate
|
||||||
|
certOut, err := os.Create(certPath)
|
||||||
|
if err != nil {
|
||||||
|
logger.Logger.Printf("Error creating cert file %s: %v", certPath, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer certOut.Close()
|
||||||
|
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||||
|
logger.Logger.Printf("Error encoding cert PEM: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Logger.Printf("Wrote certificate to %s", certPath)
|
||||||
|
|
||||||
|
// Write private key
|
||||||
|
keyOut, err := os.Create(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
logger.Logger.Printf("Error creating key file %s: %v", keyPath, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer keyOut.Close()
|
||||||
|
if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}); err != nil {
|
||||||
|
logger.Logger.Printf("Error encoding key PEM: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Logger.Printf("Wrote private key to %s", keyPath)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
21
golangproxy/tests/config_test.go
Normal file
21
golangproxy/tests/config_test.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golangproxy/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoadConfig(t *testing.T) {
|
||||||
|
// Test loading existing config
|
||||||
|
// Test creating default config
|
||||||
|
os.Remove("test_config.yaml")
|
||||||
|
config, err := config.LoadConfig("test_config.yaml")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error loading config: %v", err)
|
||||||
|
}
|
||||||
|
if config.ListenHTTP != ":80" {
|
||||||
|
t.Errorf("Expected ListenHTTP :80, got %s", config.ListenHTTP)
|
||||||
|
}
|
||||||
|
}
|
||||||
15
golangproxy/tests/proxy_test.go
Normal file
15
golangproxy/tests/proxy_test.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golangproxy/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateRoute(t *testing.T) {
|
||||||
|
// Test HTTP target
|
||||||
|
route := proxy.CreateRoute("http://example.com", false)
|
||||||
|
if route.Target != "http://example.com" {
|
||||||
|
t.Errorf("Expected target http://example.com, got %s", route.Target)
|
||||||
|
}
|
||||||
|
}
|
||||||
9
golangproxy/tests/server_test.go
Normal file
9
golangproxy/tests/server_test.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStartServer(t *testing.T) {
|
||||||
|
// Test requires mocking or running server in a goroutine
|
||||||
|
}
|
||||||
19
golangproxy/tests/ssl_test.go
Normal file
19
golangproxy/tests/ssl_test.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golangproxy/ssl"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEnsureCertFiles(t *testing.T) {
|
||||||
|
os.RemoveAll("ssl")
|
||||||
|
err := ssl.EnsureCertFiles("ssl/cert.pem", "ssl/key.pem")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error generating certs: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := os.Stat("ssl/cert.pem"); os.IsNotExist(err) {
|
||||||
|
t.Error("Certificate file not created")
|
||||||
|
}
|
||||||
|
}
|
||||||
205
main.go
205
main.go
@@ -1,205 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Shared global variables used across all files
|
|
||||||
var (
|
|
||||||
config Config // Holds the proxy configuration loaded from config.yaml
|
|
||||||
configMux sync.RWMutex // Mutex for thread-safe access to config
|
|
||||||
cert *tls.Certificate // TLS certificate for HTTPS server
|
|
||||||
baseDir string // Base directory (working dir for go run, executable dir for binary)
|
|
||||||
certDir string // Directory for certificates
|
|
||||||
certPath string // Path to certificate file
|
|
||||||
keyPath string // Path to private key file
|
|
||||||
configPath string // Path to config.yaml
|
|
||||||
|
|
||||||
// Loggers for different types of events
|
|
||||||
errorLogger *log.Logger // Logs errors to errors.log and console
|
|
||||||
refreshLogger *log.Logger // Logs config/certificate refreshes to refresh.log and console
|
|
||||||
trafficLogger *log.Logger // Logs traffic to traffic-YYYY-MM-DD.log and console
|
|
||||||
)
|
|
||||||
|
|
||||||
// init sets up the base directory and initializes logging
|
|
||||||
func init() {
|
|
||||||
// Determine base directory based on execution context
|
|
||||||
if runtime.GOOS == "windows" && len(os.Args) > 0 && filepath.Ext(os.Args[0]) == ".go" {
|
|
||||||
// For "go run", use current working directory
|
|
||||||
var err error
|
|
||||||
baseDir, err = os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to get working directory: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// For compiled binary, use executable's directory
|
|
||||||
exePath, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to get executable path: %v", err)
|
|
||||||
}
|
|
||||||
baseDir = filepath.Dir(exePath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize logging to files and console
|
|
||||||
setupLogging()
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupLogging configures loggers for errors, refreshes, and traffic
|
|
||||||
func setupLogging() {
|
|
||||||
logsDir := filepath.Join(baseDir, "logs")
|
|
||||||
if err := os.MkdirAll(logsDir, 0755); err != nil {
|
|
||||||
log.Fatalf("Failed to create logs directory: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error log: writes to errors.log and stdout
|
|
||||||
errorFile, err := os.OpenFile(filepath.Join(logsDir, "errors.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to open errors.log: %v", err)
|
|
||||||
}
|
|
||||||
errorLogger = log.New(io.MultiWriter(os.Stdout, errorFile), "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile)
|
|
||||||
|
|
||||||
// Refresh log: writes to refresh.log and stdout
|
|
||||||
refreshFile, err := os.OpenFile(filepath.Join(logsDir, "refresh.log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to open refresh.log: %v", err)
|
|
||||||
}
|
|
||||||
refreshLogger = log.New(io.MultiWriter(os.Stdout, refreshFile), "REFRESH: ", log.Ldate|log.Ltime|log.Lshortfile)
|
|
||||||
|
|
||||||
// Traffic log: managed in a goroutine for daily rotation
|
|
||||||
go manageTrafficLogs(logsDir)
|
|
||||||
}
|
|
||||||
|
|
||||||
// manageTrafficLogs handles daily rotation of traffic logs with 7-day retention
|
|
||||||
func manageTrafficLogs(logsDir string) {
|
|
||||||
for {
|
|
||||||
dateStr := time.Now().Format("2006-01-02")
|
|
||||||
trafficFile, err := os.OpenFile(filepath.Join(logsDir, "traffic-"+dateStr+".log"), os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Failed to open traffic log: %v", err)
|
|
||||||
time.Sleep(1 * time.Minute)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
trafficLogger = log.New(io.MultiWriter(os.Stdout, trafficFile), "TRAFFIC: ", log.Ldate|log.Ltime)
|
|
||||||
|
|
||||||
// Cleanup logs older than 7 days
|
|
||||||
cleanupOldLogs(logsDir)
|
|
||||||
|
|
||||||
// Wait until next day
|
|
||||||
nextDay := time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour)
|
|
||||||
time.Sleep(time.Until(nextDay))
|
|
||||||
trafficFile.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanupOldLogs removes traffic logs older than 7 days
|
|
||||||
func cleanupOldLogs(logsDir string) {
|
|
||||||
files, err := os.ReadDir(logsDir)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Failed to read logs directory: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cutoff := time.Now().AddDate(0, 0, -7)
|
|
||||||
for _, file := range files {
|
|
||||||
if strings.HasPrefix(file.Name(), "traffic-") && file.Name() != "traffic-"+time.Now().Format("2006-01-02")+".log" {
|
|
||||||
info, err := file.Info()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if info.ModTime().Before(cutoff) {
|
|
||||||
if err := os.Remove(filepath.Join(logsDir, file.Name())); err != nil {
|
|
||||||
errorLogger.Printf("Failed to remove old log %s: %v", file.Name(), err)
|
|
||||||
} else {
|
|
||||||
refreshLogger.Printf("Removed old traffic log: %s", file.Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// main is the entry point, setting up and running the HTTP/HTTPS servers
|
|
||||||
func main() {
|
|
||||||
// Set config file path and load or generate initial config
|
|
||||||
configPath = filepath.Join(baseDir, "config.yaml")
|
|
||||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
|
||||||
config = generateDefaultConfig()
|
|
||||||
if err := saveConfig(config); err != nil {
|
|
||||||
errorLogger.Fatalf("Failed to save default config: %v", err)
|
|
||||||
}
|
|
||||||
refreshLogger.Println("Generated default config file")
|
|
||||||
} else {
|
|
||||||
cfg, err := loadConfig()
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
config = cfg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update certificate and config paths based on loaded config
|
|
||||||
updatePaths()
|
|
||||||
|
|
||||||
// Generate or load TLS certificate
|
|
||||||
_, certErr := os.Stat(certPath)
|
|
||||||
_, keyErr := os.Stat(keyPath)
|
|
||||||
if os.IsNotExist(certErr) || os.IsNotExist(keyErr) {
|
|
||||||
if err := generateSelfSignedCert(); err != nil {
|
|
||||||
errorLogger.Fatalf("Failed to generate self-signed certificate: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := loadCertificate(); err != nil {
|
|
||||||
errorLogger.Fatalf("Failed to load certificate: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start background monitoring for config and certificate changes
|
|
||||||
go monitorCertificates()
|
|
||||||
go monitorConfig()
|
|
||||||
|
|
||||||
// Configure HTTP server with timeouts for robustness
|
|
||||||
httpServer := &http.Server{
|
|
||||||
Addr: config.ListenHTTP,
|
|
||||||
Handler: http.HandlerFunc(handler),
|
|
||||||
MaxHeaderBytes: 1 << 20, // 1 MB max header size
|
|
||||||
ReadTimeout: 10 * time.Second,
|
|
||||||
WriteTimeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Configure HTTPS server with TLS and certificate fetching
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
configMux.RLock()
|
|
||||||
defer configMux.RUnlock()
|
|
||||||
return cert, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
httpsServer := &http.Server{
|
|
||||||
Addr: config.ListenHTTPS,
|
|
||||||
Handler: http.HandlerFunc(handler),
|
|
||||||
TLSConfig: tlsConfig,
|
|
||||||
MaxHeaderBytes: 1 << 20,
|
|
||||||
ReadTimeout: 10 * time.Second,
|
|
||||||
WriteTimeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start HTTP server in a goroutine
|
|
||||||
go func() {
|
|
||||||
refreshLogger.Printf("Starting HTTP server on %s", config.ListenHTTP)
|
|
||||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
||||||
errorLogger.Fatalf("HTTP server error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Start HTTPS server in the main goroutine
|
|
||||||
refreshLogger.Printf("Starting HTTPS server on %s", config.ListenHTTPS)
|
|
||||||
if err := httpsServer.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
|
||||||
errorLogger.Fatalf("HTTPS server error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
370
proxy.go
370
proxy.go
@@ -1,370 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"compress/gzip"
|
|
||||||
"crypto/md5"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/time/rate"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// transportPool configures HTTP transport with timeouts matching Nginx defaults
|
|
||||||
transportPool = &http.Transport{
|
|
||||||
MaxIdleConns: 100, // Max idle connections
|
|
||||||
MaxIdleConnsPerHost: 10, // Max idle per host
|
|
||||||
IdleConnTimeout: 90 * time.Second, // Idle connection timeout
|
|
||||||
DialContext: (&net.Dialer{Timeout: 30 * time.Second}).DialContext, // Dial timeout
|
|
||||||
ResponseHeaderTimeout: 60 * time.Second, // Response header timeout (Nginx-like)
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second, // TLS handshake timeout
|
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // Default skip TLS verification
|
|
||||||
}
|
|
||||||
cache = make(map[string]cachedResponse) // Cache for static responses
|
|
||||||
cacheMutex sync.RWMutex // Mutex for cache access
|
|
||||||
|
|
||||||
// Rate limiting per client IP
|
|
||||||
rateLimiters = make(map[string]*rate.Limiter)
|
|
||||||
rateMutex sync.RWMutex
|
|
||||||
rateLimit = rate.Limit(10) // 10 req/s per client
|
|
||||||
rateBurst = 20 // Burst allowance
|
|
||||||
|
|
||||||
defaultCacheTTL = 5 * time.Minute // Default TTL for cached responses
|
|
||||||
)
|
|
||||||
|
|
||||||
// cachedResponse stores cached response details
|
|
||||||
type cachedResponse struct {
|
|
||||||
body []byte
|
|
||||||
headers http.Header
|
|
||||||
statusCode int
|
|
||||||
cachedAt time.Time
|
|
||||||
cacheDuration time.Duration
|
|
||||||
etag string
|
|
||||||
}
|
|
||||||
|
|
||||||
// getReverseProxy creates a reverse proxy for a target URL
|
|
||||||
func getReverseProxy(target string, skipVerify bool, originalReq *http.Request) *httputil.ReverseProxy {
|
|
||||||
targetURL, err := url.Parse(target)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error parsing target URL %s: %v", target, err)
|
|
||||||
errorLogger.Printf("Error parsing target URL %s: %v", target, err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// director modifies the request to forward it to the target
|
|
||||||
director := func(req *http.Request) {
|
|
||||||
req.URL.Scheme = targetURL.Scheme
|
|
||||||
req.URL.Host = targetURL.Host
|
|
||||||
// Preserve the original client’s Host header for session continuity
|
|
||||||
if originalReq.Header.Get("Host") != "" {
|
|
||||||
req.Host = originalReq.Header.Get("Host")
|
|
||||||
} else {
|
|
||||||
req.Host = targetURL.Host
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy all request headers to ensure cookies and session data are preserved
|
|
||||||
for k, v := range originalReq.Header {
|
|
||||||
req.Header[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// Preserve query parameters and full path
|
|
||||||
req.URL.RawQuery = originalReq.URL.RawQuery
|
|
||||||
if targetURL.Path != "" {
|
|
||||||
req.URL.Path = strings.TrimPrefix(originalReq.URL.Path, "/")
|
|
||||||
req.URL.Path = singleJoin(targetURL.Path, req.URL.Path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure WebSocket headers are passed correctly
|
|
||||||
if strings.ToLower(req.Header.Get("Upgrade")) == "websocket" {
|
|
||||||
req.Header.Set("Connection", "Upgrade")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log request details for debugging
|
|
||||||
trafficLogger.Printf("Request: %s %s -> %s [Host: %s] [Headers: %v]", req.Method, req.URL.String(), target, req.Host, req.Header)
|
|
||||||
}
|
|
||||||
|
|
||||||
transport := transportPool
|
|
||||||
if !skipVerify {
|
|
||||||
// Use a new transport with TLS verification if skipVerify is false
|
|
||||||
transport = &http.Transport{
|
|
||||||
MaxIdleConns: transportPool.MaxIdleConns,
|
|
||||||
MaxIdleConnsPerHost: transportPool.MaxIdleConnsPerHost,
|
|
||||||
IdleConnTimeout: transportPool.IdleConnTimeout,
|
|
||||||
DialContext: transportPool.DialContext,
|
|
||||||
ResponseHeaderTimeout: transportPool.ResponseHeaderTimeout,
|
|
||||||
TLSHandshakeTimeout: transportPool.TLSHandshakeTimeout,
|
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &httputil.ReverseProxy{
|
|
||||||
Director: director,
|
|
||||||
Transport: transport,
|
|
||||||
// ModifyResponse processes responses, avoiding corruption
|
|
||||||
ModifyResponse: func(resp *http.Response) error {
|
|
||||||
// Preserve all response headers, including Set-Cookie for sessions
|
|
||||||
for k, v := range resp.Header {
|
|
||||||
resp.Header[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip modification for WebSocket responses
|
|
||||||
if strings.ToLower(resp.Header.Get("Upgrade")) == "websocket" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply compression only if not already compressed and client supports it
|
|
||||||
if resp.Header.Get("Content-Encoding") == "" && strings.Contains(resp.Request.Header.Get("Accept-Encoding"), "gzip") {
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error reading response body for compression: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
gw := gzip.NewWriter(&buf)
|
|
||||||
if _, err := gw.Write(body); err != nil {
|
|
||||||
errorLogger.Printf("Error compressing response: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
gw.Close()
|
|
||||||
|
|
||||||
resp.Body = io.NopCloser(&buf)
|
|
||||||
resp.Header.Set("Content-Encoding", "gzip")
|
|
||||||
resp.Header.Del("Content-Length")
|
|
||||||
trafficLogger.Printf("Compressed response for %s", resp.Request.URL.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache static content if applicable
|
|
||||||
if shouldCache(resp) {
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error reading response body for caching: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
|
||||||
|
|
||||||
etag := resp.Header.Get("ETag")
|
|
||||||
if etag == "" {
|
|
||||||
etag = generateETag(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheDuration := parseCacheControl(resp.Header.Get("Cache-Control"))
|
|
||||||
if cacheDuration == 0 {
|
|
||||||
cacheDuration = defaultCacheTTL
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheMutex.Lock()
|
|
||||||
cache[resp.Request.URL.String()] = cachedResponse{
|
|
||||||
body: body,
|
|
||||||
headers: resp.Header.Clone(),
|
|
||||||
statusCode: resp.StatusCode,
|
|
||||||
cachedAt: time.Now(),
|
|
||||||
cacheDuration: cacheDuration,
|
|
||||||
etag: etag,
|
|
||||||
}
|
|
||||||
cacheMutex.Unlock()
|
|
||||||
resp.Header.Set("ETag", etag)
|
|
||||||
trafficLogger.Printf("Cached response for %s [ETag: %s]", resp.Request.URL.String(), etag)
|
|
||||||
}
|
|
||||||
|
|
||||||
trafficLogger.Printf("Response: %s %d from %s [Headers: %v]", resp.Status, resp.StatusCode, target, resp.Header)
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
if err.Error() == "context canceled" {
|
|
||||||
trafficLogger.Printf("Request canceled by client for %s: %v", r.Host, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("Proxy error for %s: %v", r.Host, err)
|
|
||||||
errorLogger.Printf("Proxy error for %s: %v", r.Host, err)
|
|
||||||
http.Error(w, "Proxy error", http.StatusBadGateway)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// shouldCache checks if a response should be cached based on method and content type
|
|
||||||
func shouldCache(resp *http.Response) bool {
|
|
||||||
if resp.Request.Method != "GET" || resp.StatusCode != http.StatusOK {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
|
||||||
return strings.HasPrefix(contentType, "text/") || strings.HasPrefix(contentType, "image/") ||
|
|
||||||
strings.HasPrefix(contentType, "application/javascript") || strings.HasPrefix(contentType, "application/json")
|
|
||||||
}
|
|
||||||
|
|
||||||
// singleJoin combines path segments with a single slash
|
|
||||||
func singleJoin(prefix, suffix string) string {
|
|
||||||
prefix = strings.TrimSuffix(prefix, "/")
|
|
||||||
suffix = strings.TrimPrefix(suffix, "/")
|
|
||||||
return prefix + "/" + suffix
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateETag creates an ETag from the response body using MD5
|
|
||||||
func generateETag(body []byte) string {
|
|
||||||
return fmt.Sprintf(`"%x"`, md5.Sum(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseCacheControl extracts max-age from the Cache-Control header for caching duration
|
|
||||||
func parseCacheControl(header string) time.Duration {
|
|
||||||
if header == "" {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
parts := strings.Split(header, ",")
|
|
||||||
for _, part := range parts {
|
|
||||||
if strings.HasPrefix(part, "max-age=") {
|
|
||||||
ageStr := strings.TrimSpace(strings.TrimPrefix(part, "max-age="))
|
|
||||||
if age, err := strconv.Atoi(ageStr); err == nil {
|
|
||||||
return time.Duration(age) * time.Second
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// getLimiter manages rate limiters per client IP
|
|
||||||
func getLimiter(ip string) *rate.Limiter {
|
|
||||||
rateMutex.Lock()
|
|
||||||
defer rateMutex.Unlock()
|
|
||||||
|
|
||||||
if limiter, exists := rateLimiters[ip]; exists {
|
|
||||||
return limiter
|
|
||||||
}
|
|
||||||
limiter := rate.NewLimiter(rateLimit, rateBurst)
|
|
||||||
rateLimiters[ip] = limiter
|
|
||||||
return limiter
|
|
||||||
}
|
|
||||||
|
|
||||||
// handler processes incoming HTTP/HTTPS requests
|
|
||||||
func handler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
configMux.RLock()
|
|
||||||
defer configMux.RUnlock()
|
|
||||||
|
|
||||||
// Rate limiting based on client IP
|
|
||||||
clientIP := r.RemoteAddr[:strings.LastIndex(r.RemoteAddr, ":")]
|
|
||||||
limiter := getLimiter(clientIP)
|
|
||||||
if !limiter.Allow() {
|
|
||||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
|
||||||
trafficLogger.Printf("Rate limited client %s", clientIP)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve target and settings for the requested host
|
|
||||||
target, exists := config.Routes[r.Host]
|
|
||||||
skipVerify := config.TrustTarget[r.Host]
|
|
||||||
noHTTPSRedirect := config.NoHTTPSRedirect[r.Host]
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
if target, exists = config.Routes["*"]; !exists {
|
|
||||||
http.Error(w, "Host not configured", http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
skipVerify = config.TrustTarget["*"]
|
|
||||||
noHTTPSRedirect = config.NoHTTPSRedirect["*"]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle HTTP->HTTPS redirect if applicable
|
|
||||||
isHTTPS := target[:5] == "https"
|
|
||||||
isHTTPReq := r.TLS == nil
|
|
||||||
if isHTTPReq && isHTTPS && !noHTTPSRedirect {
|
|
||||||
redirectURL := "https://" + r.Host + r.RequestURI
|
|
||||||
http.Redirect(w, r, redirectURL, http.StatusMovedPermanently)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check cache for static content
|
|
||||||
cacheKey := r.URL.String()
|
|
||||||
cacheMutex.RLock()
|
|
||||||
if cached, ok := cache[cacheKey]; ok && time.Since(cached.cachedAt) < cached.cacheDuration {
|
|
||||||
if etag := r.Header.Get("If-None-Match"); etag != "" && etag == cached.etag {
|
|
||||||
w.WriteHeader(http.StatusNotModified)
|
|
||||||
trafficLogger.Printf("Served 304 Not Modified from cache: %s [ETag: %s]", cacheKey, cached.etag)
|
|
||||||
} else {
|
|
||||||
for k, v := range cached.headers {
|
|
||||||
w.Header()[k] = v
|
|
||||||
}
|
|
||||||
w.WriteHeader(cached.statusCode)
|
|
||||||
w.Write(cached.body)
|
|
||||||
trafficLogger.Printf("Served from cache: %s [ETag: %s]", cacheKey, cached.etag)
|
|
||||||
}
|
|
||||||
cacheMutex.RUnlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cacheMutex.RUnlock()
|
|
||||||
|
|
||||||
proxy := getReverseProxy(target, skipVerify, r)
|
|
||||||
if proxy == nil {
|
|
||||||
http.Error(w, "Invalid target configuration", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle WebSocket connections
|
|
||||||
if strings.ToLower(r.Header.Get("Upgrade")) == "websocket" {
|
|
||||||
targetURL, err := url.Parse(target)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Error parsing target URL for WebSocket: %v", err)
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hijacker, ok := w.(http.Hijacker)
|
|
||||||
if !ok {
|
|
||||||
errorLogger.Printf("WebSocket upgrade not supported by server")
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
clientConn, _, err := hijacker.Hijack()
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Failed to hijack connection for WebSocket: %v", err)
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer clientConn.Close()
|
|
||||||
|
|
||||||
dialer := transportPool
|
|
||||||
if !skipVerify {
|
|
||||||
dialer = &http.Transport{
|
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
targetConn, err := dialer.Dial("tcp", targetURL.Host)
|
|
||||||
if err != nil {
|
|
||||||
errorLogger.Printf("Failed to dial target for WebSocket: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer targetConn.Close()
|
|
||||||
|
|
||||||
if err := r.Write(targetConn); err != nil {
|
|
||||||
errorLogger.Printf("Failed to forward WebSocket request: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
|
||||||
go func() {
|
|
||||||
_, err := io.Copy(targetConn, clientConn)
|
|
||||||
errChan <- err
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
_, err := io.Copy(clientConn, targetConn)
|
|
||||||
errChan <- err
|
|
||||||
}()
|
|
||||||
trafficLogger.Printf("WebSocket connection established: %s -> %s", r.Host, target)
|
|
||||||
<-errChan
|
|
||||||
trafficLogger.Printf("WebSocket connection closed: %s -> %s", r.Host, target)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
proxy.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
168
proxy_test.go
168
proxy_test.go
@@ -1,168 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Mock globals for testing
|
|
||||||
var (
|
|
||||||
mockConfig = Config{
|
|
||||||
Routes: map[string]string{
|
|
||||||
"test.local": "http://mock-target",
|
|
||||||
"ws.local": "ws://mock-target",
|
|
||||||
"cache.local": "http://mock-target",
|
|
||||||
},
|
|
||||||
TrustTarget: map[string]bool{
|
|
||||||
"test.local": true,
|
|
||||||
"ws.local": true,
|
|
||||||
"cache.local": true,
|
|
||||||
},
|
|
||||||
NoHTTPSRedirect: map[string]bool{
|
|
||||||
"test.local": true,
|
|
||||||
"ws.local": true,
|
|
||||||
"cache.local": true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
mockConfigMux = sync.RWMutex{}
|
|
||||||
mockLogger = log.New(io.Discard, "", 0) // Discard logs during testing
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestHandlerRoute tests basic routing to a target
|
|
||||||
func TestHandlerRoute(t *testing.T) {
|
|
||||||
// Mock target server
|
|
||||||
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Write([]byte("Hello from target"))
|
|
||||||
}))
|
|
||||||
defer targetServer.Close()
|
|
||||||
|
|
||||||
// Setup mock globals
|
|
||||||
configMux.Lock()
|
|
||||||
config = mockConfig
|
|
||||||
config.Routes["test.local"] = targetServer.URL
|
|
||||||
trafficLogger = mockLogger
|
|
||||||
configMux.Unlock()
|
|
||||||
|
|
||||||
// Create request
|
|
||||||
req, _ := http.NewRequest("GET", "http://test.local", nil)
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
|
|
||||||
// Run handler
|
|
||||||
handler(rr, req)
|
|
||||||
|
|
||||||
// Check response
|
|
||||||
if status := rr.Code; status != http.StatusOK {
|
|
||||||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
|
||||||
}
|
|
||||||
if body := rr.Body.String(); body != "Hello from target" {
|
|
||||||
t.Errorf("handler returned unexpected body: got %v want %v", body, "Hello from target")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestHandlerWebSocket tests WebSocket upgrade handling
|
|
||||||
func TestHandlerWebSocket(t *testing.T) {
|
|
||||||
// Mock WebSocket target (simplified)
|
|
||||||
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Header.Get("Upgrade") == "websocket" {
|
|
||||||
hj, ok := w.(http.Hijacker)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("Server doesn’t support hijacking")
|
|
||||||
}
|
|
||||||
conn, _, err := hj.Hijack()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
conn.Write([]byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"))
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
defer targetServer.Close()
|
|
||||||
|
|
||||||
// Setup mock globals
|
|
||||||
configMux.Lock()
|
|
||||||
config = mockConfig
|
|
||||||
config.Routes["ws.local"] = targetServer.URL
|
|
||||||
trafficLogger = mockLogger
|
|
||||||
configMux.Unlock()
|
|
||||||
|
|
||||||
// Create WebSocket request
|
|
||||||
req, _ := http.NewRequest("GET", "http://ws.local", nil)
|
|
||||||
req.Header.Set("Upgrade", "websocket")
|
|
||||||
req.Header.Set("Connection", "Upgrade")
|
|
||||||
|
|
||||||
// Use a custom recorder for hijacking
|
|
||||||
rr := &hijackRecorder{ResponseRecorder: httptest.NewRecorder()}
|
|
||||||
handler(rr, req)
|
|
||||||
|
|
||||||
// Check if hijacking occurred
|
|
||||||
if !rr.hijacked {
|
|
||||||
t.Errorf("Expected WebSocket hijacking, but it didn’t occur")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestHandlerCache tests caching functionality
|
|
||||||
func TestHandlerCache(t *testing.T) {
|
|
||||||
// Mock target server
|
|
||||||
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
|
||||||
w.Write([]byte("Cached content"))
|
|
||||||
}))
|
|
||||||
defer targetServer.Close()
|
|
||||||
|
|
||||||
// Setup mock globals
|
|
||||||
configMux.Lock()
|
|
||||||
config = mockConfig
|
|
||||||
config.Routes["cache.local"] = targetServer.URL
|
|
||||||
trafficLogger = mockLogger
|
|
||||||
cache = make(map[string]cachedResponse) // Reset cache for test isolation
|
|
||||||
configMux.Unlock()
|
|
||||||
|
|
||||||
// First request to cache
|
|
||||||
req, _ := http.NewRequest("GET", "http://cache.local", nil)
|
|
||||||
rr1 := httptest.NewRecorder()
|
|
||||||
handler(rr1, req)
|
|
||||||
|
|
||||||
// Second request to hit cache
|
|
||||||
rr2 := httptest.NewRecorder()
|
|
||||||
handler(rr2, req)
|
|
||||||
|
|
||||||
// Check if second response is from cache
|
|
||||||
if rr2.Code != http.StatusOK {
|
|
||||||
t.Errorf("Expected status OK, got %v", rr2.Code)
|
|
||||||
}
|
|
||||||
if body := rr2.Body.String(); body != "Cached content" {
|
|
||||||
t.Errorf("Expected cached response, got %v", body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// hijackRecorder mocks ResponseRecorder with Hijack support for WebSocket testing
|
|
||||||
type hijackRecorder struct {
|
|
||||||
*httptest.ResponseRecorder
|
|
||||||
hijacked bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hr *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
||||||
hr.hijacked = true
|
|
||||||
return &mockConn{Reader: bytes.NewReader([]byte("mock response")), Writer: hr.Body}, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// mockConn is a simple net.Conn mock for testing
|
|
||||||
type mockConn struct {
|
|
||||||
io.Reader
|
|
||||||
io.Writer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mc *mockConn) Close() error { return nil }
|
|
||||||
func (mc *mockConn) LocalAddr() net.Addr { return nil }
|
|
||||||
func (mc *mockConn) RemoteAddr() net.Addr { return nil }
|
|
||||||
func (mc *mockConn) SetDeadline(t time.Time) error { return nil }
|
|
||||||
func (mc *mockConn) SetReadDeadline(t time.Time) error { return nil }
|
|
||||||
func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil }
|
|
||||||
11
utils.go
11
utils.go
@@ -1,11 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import "path/filepath"
|
|
||||||
|
|
||||||
// updatePaths sets certificate and config file paths based on the current config
|
|
||||||
func updatePaths() {
|
|
||||||
certDir = filepath.Join(baseDir, config.CertDir) // Certificate directory path
|
|
||||||
certPath = filepath.Join(certDir, config.CertFile) // Full certificate file path
|
|
||||||
keyPath = filepath.Join(certDir, config.KeyFile) // Full private key file path
|
|
||||||
configPath = filepath.Join(baseDir, "config.yaml") // Config file path
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user