// Tests for the request handler: host-based routing, WebSocket upgrade
// handling, and response caching.

package main
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"io"
|
||
"log"
|
||
"net"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"sync"
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
// Mock globals for testing
|
||
var (
|
||
mockConfig = Config{
|
||
Routes: map[string]string{
|
||
"test.local": "http://mock-target",
|
||
"ws.local": "ws://mock-target",
|
||
"cache.local": "http://mock-target",
|
||
},
|
||
TrustTarget: map[string]bool{
|
||
"test.local": true,
|
||
"ws.local": true,
|
||
"cache.local": true,
|
||
},
|
||
NoHTTPSRedirect: map[string]bool{
|
||
"test.local": true,
|
||
"ws.local": true,
|
||
"cache.local": true,
|
||
},
|
||
}
|
||
mockConfigMux = sync.RWMutex{}
|
||
mockLogger = log.New(io.Discard, "", 0) // Discard logs during testing
|
||
)
|
||
|
||
// TestHandlerRoute tests basic routing to a target
|
||
func TestHandlerRoute(t *testing.T) {
|
||
// Mock target server
|
||
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.Write([]byte("Hello from target"))
|
||
}))
|
||
defer targetServer.Close()
|
||
|
||
// Setup mock globals
|
||
configMux.Lock()
|
||
config = mockConfig
|
||
config.Routes["test.local"] = targetServer.URL
|
||
trafficLogger = mockLogger
|
||
configMux.Unlock()
|
||
|
||
// Create request
|
||
req, _ := http.NewRequest("GET", "http://test.local", nil)
|
||
rr := httptest.NewRecorder()
|
||
|
||
// Run handler
|
||
handler(rr, req)
|
||
|
||
// Check response
|
||
if status := rr.Code; status != http.StatusOK {
|
||
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
|
||
}
|
||
if body := rr.Body.String(); body != "Hello from target" {
|
||
t.Errorf("handler returned unexpected body: got %v want %v", body, "Hello from target")
|
||
}
|
||
}
|
||
|
||
// TestHandlerWebSocket tests WebSocket upgrade handling
|
||
func TestHandlerWebSocket(t *testing.T) {
|
||
// Mock WebSocket target (simplified)
|
||
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
if r.Header.Get("Upgrade") == "websocket" {
|
||
hj, ok := w.(http.Hijacker)
|
||
if !ok {
|
||
t.Fatal("Server doesn’t support hijacking")
|
||
}
|
||
conn, _, err := hj.Hijack()
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
conn.Write([]byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n"))
|
||
conn.Close()
|
||
}
|
||
}))
|
||
defer targetServer.Close()
|
||
|
||
// Setup mock globals
|
||
configMux.Lock()
|
||
config = mockConfig
|
||
config.Routes["ws.local"] = targetServer.URL
|
||
trafficLogger = mockLogger
|
||
configMux.Unlock()
|
||
|
||
// Create WebSocket request
|
||
req, _ := http.NewRequest("GET", "http://ws.local", nil)
|
||
req.Header.Set("Upgrade", "websocket")
|
||
req.Header.Set("Connection", "Upgrade")
|
||
|
||
// Use a custom recorder for hijacking
|
||
rr := &hijackRecorder{ResponseRecorder: httptest.NewRecorder()}
|
||
handler(rr, req)
|
||
|
||
// Check if hijacking occurred
|
||
if !rr.hijacked {
|
||
t.Errorf("Expected WebSocket hijacking, but it didn’t occur")
|
||
}
|
||
}
|
||
|
||
// TestHandlerCache tests caching functionality
|
||
func TestHandlerCache(t *testing.T) {
|
||
// Mock target server
|
||
targetServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.Header().Set("Content-Type", "text/plain")
|
||
w.Write([]byte("Cached content"))
|
||
}))
|
||
defer targetServer.Close()
|
||
|
||
// Setup mock globals
|
||
configMux.Lock()
|
||
config = mockConfig
|
||
config.Routes["cache.local"] = targetServer.URL
|
||
trafficLogger = mockLogger
|
||
cache = make(map[string]cachedResponse) // Reset cache for test isolation
|
||
configMux.Unlock()
|
||
|
||
// First request to cache
|
||
req, _ := http.NewRequest("GET", "http://cache.local", nil)
|
||
rr1 := httptest.NewRecorder()
|
||
handler(rr1, req)
|
||
|
||
// Second request to hit cache
|
||
rr2 := httptest.NewRecorder()
|
||
handler(rr2, req)
|
||
|
||
// Check if second response is from cache
|
||
if rr2.Code != http.StatusOK {
|
||
t.Errorf("Expected status OK, got %v", rr2.Code)
|
||
}
|
||
if body := rr2.Body.String(); body != "Cached content" {
|
||
t.Errorf("Expected cached response, got %v", body)
|
||
}
|
||
}
|
||
|
||
// Compile-time checks that the test doubles below satisfy the standard
// interfaces they stand in for.
var (
	_ http.Hijacker = (*hijackRecorder)(nil)
	_ net.Conn      = (*mockConn)(nil)
)

// hijackRecorder is an httptest.ResponseRecorder that can also be hijacked,
// for exercising code paths that take over the client connection (for example
// WebSocket upgrades).
type hijackRecorder struct {
	*httptest.ResponseRecorder
	hijacked bool // set to true the first time Hijack is called
}

// Hijack records that the connection was taken over and returns a mock
// connection together with a buffered ReadWriter wrapped around it.
//
// Reads from the connection yield the fixed bytes "mock response"; writes are
// appended to the recorder's Body. The ReadWriter is never nil, so callers
// that flush or write through it after hijacking do not dereference nil. The
// returned error is always nil.
func (hr *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hr.hijacked = true

	conn := &mockConn{
		Reader: bytes.NewReader([]byte("mock response")),
		Writer: hr.Body,
	}
	rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
	return conn, rw, nil
}

// mockConn is a minimal net.Conn for tests: Read and Write are delegated to
// the embedded Reader and Writer, and every other method is a no-op.
type mockConn struct {
	io.Reader
	io.Writer
}

// Close does nothing and reports success.
func (mc *mockConn) Close() error {
	return nil
}

// LocalAddr returns nil; the mock has no local address.
func (mc *mockConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil; the mock has no remote address.
func (mc *mockConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline ignores the deadline and reports success.
func (mc *mockConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline and reports success.
func (mc *mockConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline and reports success.
func (mc *mockConn) SetWriteDeadline(t time.Time) error {
	return nil
}
|