package middleware import ( "context" "crypto/hmac" "crypto/sha256" "crypto/subtle" "encoding/hex" "fmt" "log" "net/http" "runtime/debug" "strconv" "strings" "time" ) // Middleware wraps an http.Handler. type Middleware func(http.Handler) http.Handler // Chain applies middlewares in order (first = outermost). func Chain(mws ...Middleware) Middleware { return func(next http.Handler) http.Handler { for i := len(mws) - 1; i >= 0; i-- { next = mws[i](next) } return next } } // ----------------------------------------------------------------------- // Context keys // ----------------------------------------------------------------------- type contextKey int const ( sessionContextKey contextKey = iota csrfContextKey ) // ----------------------------------------------------------------------- // Session constants // ----------------------------------------------------------------------- const ( sessionCookieName = "cs_session" sessionTTL = 8 * time.Hour ) // ----------------------------------------------------------------------- // SessionAuth middleware // ----------------------------------------------------------------------- // SessionAuth validates the session cookie on every request. // /login and /static/ are exempt — all other paths redirect to /login when unauthenticated. // On success, the CSRF token is stored in the request context. func SessionAuth(secret string) Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/login" || strings.HasPrefix(r.URL.Path, "/static/") { next.ServeHTTP(w, r) return } cookie, err := r.Cookie(sessionCookieName) if err != nil { http.Redirect(w, r, "/login", http.StatusSeeOther) return } username, ok := validateSessionValue(secret, cookie.Value) if !ok { http.SetCookie(w, expiredCookie()) http.Redirect(w, r, "/login", http.StatusSeeOther) return } csrf := deriveCSRF(secret, cookie.Value) ctx := context.WithValue(r.Context(), sessionContextKey, username) ctx = context.WithValue(ctx, csrfContextKey, csrf) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // ----------------------------------------------------------------------- // Cookie helpers (exported for use in auth handler) // ----------------------------------------------------------------------- // NewSessionCookie returns a signed, HttpOnly session cookie. func NewSessionCookie(secret, username string) *http.Cookie { exp := time.Now().Add(sessionTTL).Unix() payload := fmt.Sprintf("%d.%s", exp, username) mac := computeHMAC(secret, payload) return &http.Cookie{ Name: sessionCookieName, Value: payload + "." + mac, Path: "/", HttpOnly: true, SameSite: http.SameSiteStrictMode, MaxAge: int(sessionTTL.Seconds()), } } // ClearSessionCookie returns an expired cookie that clears the session. func ClearSessionCookie() *http.Cookie { return expiredCookie() } func expiredCookie() *http.Cookie { return &http.Cookie{ Name: sessionCookieName, Value: "", Path: "/", HttpOnly: true, MaxAge: -1, } } // ----------------------------------------------------------------------- // Context accessors (exported for handlers) // ----------------------------------------------------------------------- // CSRFFromContext returns the CSRF token stored by SessionAuth, or "". func CSRFFromContext(r *http.Request) string { v, _ := r.Context().Value(csrfContextKey).(string) return v } // ----------------------------------------------------------------------- // Internal helpers // ----------------------------------------------------------------------- // validateSessionValue parses and verifies a session cookie value. // Format: ".." // HMAC covers ".". func validateSessionValue(secret, value string) (username string, ok bool) { lastDot := strings.LastIndex(value, ".") if lastDot < 0 { return "", false } payload := value[:lastDot] mac := value[lastDot+1:] if subtle.ConstantTimeCompare([]byte(mac), []byte(computeHMAC(secret, payload))) != 1 { return "", false } firstDot := strings.Index(payload, ".") if firstDot < 0 { return "", false } expStr := payload[:firstDot] username = payload[firstDot+1:] exp, err := strconv.ParseInt(expStr, 10, 64) if err != nil || time.Now().Unix() > exp || username == "" { return "", false } return username, true } func deriveCSRF(secret, sessionValue string) string { h := computeHMAC(secret, "csrf:"+sessionValue) if len(h) > 32 { return h[:32] } return h } func computeHMAC(secret, data string) string { h := hmac.New(sha256.New, []byte(secret)) h.Write([]byte(data)) return hex.EncodeToString(h.Sum(nil)) } // ----------------------------------------------------------------------- // Logger // ----------------------------------------------------------------------- func Logger() Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(rw, r) log.Printf("%s %s %d %s", r.Method, r.URL.Path, rw.status, time.Since(start)) }) } } // ----------------------------------------------------------------------- // SecureHeaders // ----------------------------------------------------------------------- func SecureHeaders() Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := w.Header() h.Set("X-Content-Type-Options", "nosniff") h.Set("X-Frame-Options", "DENY") h.Set("X-XSS-Protection", "1; mode=block") h.Set("Referrer-Policy", "same-origin") h.Set("Content-Security-Policy", "default-src 'self'; "+ "script-src 'self' https://cdn.tailwindcss.com 'unsafe-inline'; "+ "style-src 'self' https://fonts.googleapis.com 'unsafe-inline'; "+ "font-src 'self' https://fonts.gstatic.com; "+ "img-src 'self' data:; "+ "connect-src 'self'; "+ "frame-ancestors 'none'") next.ServeHTTP(w, r) }) } } // ----------------------------------------------------------------------- // Recovery // ----------------------------------------------------------------------- func Recovery() Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if rec := recover(); rec != nil { log.Printf("panic: %v\n%s", rec, debug.Stack()) http.Error(w, "internal server error", http.StatusInternalServerError) } }() next.ServeHTTP(w, r) }) } } // ----------------------------------------------------------------------- // responseWriter — captures status code for logging // ----------------------------------------------------------------------- type responseWriter struct { http.ResponseWriter status int written bool } func (rw *responseWriter) WriteHeader(code int) { if !rw.written { rw.status = code rw.written = true rw.ResponseWriter.WriteHeader(code) } } func (rw *responseWriter) Write(b []byte) (int, error) { if !rw.written { rw.written = true } return rw.ResponseWriter.Write(b) }