211 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			211 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package server
 | |
| 
 | |
import (
	"crypto/rand"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"net/http"
	"net/url"
	"time"

	"github.com/gin-gonic/gin"
)
 | |
| 
 | |
// csrfSessionKey is the key in the session's Values map under which the
// per-session CSRF token string is stored and looked up.
const csrfSessionKey = "csrf_token"
 | |
| 
 | |
| func (s *Server) randomToken(n int) (string, error) {
 | |
| 	b := make([]byte, n)
 | |
| 	if _, err := rand.Read(b); err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	return base64.RawURLEncoding.EncodeToString(b), nil
 | |
| }
 | |
| 
 | |
| // SessionUser loads gorilla session and exposes user_id and csrf token to context
 | |
| func (s *Server) SessionUser() gin.HandlerFunc {
 | |
| 	return func(c *gin.Context) {
 | |
| 		sess, _ := s.store.Get(c.Request, "gobsidian_session")
 | |
| 		if v, ok := sess.Values["user_id"].(int64); ok {
 | |
| 			c.Set("user_id", v)
 | |
| 			// derive admin flag
 | |
| 			if ok, err := s.auth.IsUserInGroup(v, "admin"); err == nil && ok {
 | |
| 				c.Set("is_admin", true)
 | |
| 			}
 | |
| 		}
 | |
| 		// ensure CSRF token exists in session
 | |
| 		tok, _ := sess.Values[csrfSessionKey].(string)
 | |
| 		if tok == "" {
 | |
| 			if t, err := s.randomToken(32); err == nil {
 | |
| 				sess.Values[csrfSessionKey] = t
 | |
| 				_ = sess.Save(c.Request, c.Writer)
 | |
| 				tok = t
 | |
| 			}
 | |
| 		}
 | |
| 		c.Set("csrf_token", tok)
 | |
| 		// expose CSRF token to client: header + non-HttpOnly cookie
 | |
| 		if tok != "" {
 | |
| 			c.Writer.Header().Set("X-CSRF-Token", tok)
 | |
| 			// cookie accessible to JS (HttpOnly=false). Secure/ SameSite Lax for CSRF
 | |
| 			http.SetCookie(c.Writer, &http.Cookie{
 | |
| 				Name:     "csrf_token",
 | |
| 				Value:    tok,
 | |
| 				Path:     "/",
 | |
| 				HttpOnly: false,
 | |
| 				SameSite: http.SameSiteLaxMode,
 | |
| 			})
 | |
| 		}
 | |
| 		c.Next()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // CSRFRequire validates the CSRF token for state-changing requests
 | |
| func (s *Server) CSRFRequire() gin.HandlerFunc {
 | |
| 	return func(c *gin.Context) {
 | |
| 		if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead || c.Request.Method == http.MethodOptions {
 | |
| 			c.Next()
 | |
| 			return
 | |
| 		}
 | |
| 		sess, _ := s.store.Get(c.Request, "gobsidian_session")
 | |
| 		expected, _ := sess.Values[csrfSessionKey].(string)
 | |
| 		var token string
 | |
| 		if h := c.GetHeader("X-CSRF-Token"); h != "" {
 | |
| 			token = h
 | |
| 		} else {
 | |
| 			token = c.PostForm("csrf_token")
 | |
| 		}
 | |
| 		if expected == "" || token == "" || expected != token {
 | |
| 			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "invalid CSRF token"})
 | |
| 			return
 | |
| 		}
 | |
| 		c.Next()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // RequireAuth enforces authenticated access
 | |
| func (s *Server) RequireAuth() gin.HandlerFunc {
 | |
| 	return func(c *gin.Context) {
 | |
| 		if _, exists := c.Get("user_id"); !exists {
 | |
| 			// Attach return_to so user can be redirected back after login
 | |
| 			requested := c.Request.URL.RequestURI()
 | |
| 			q := url.Values{}
 | |
| 			if requested != "" {
 | |
| 				q.Set("return_to", requested)
 | |
| 			}
 | |
| 			loginURL := s.config.URLPrefix + "/editor/login"
 | |
| 			if qs := q.Encode(); qs != "" {
 | |
| 				loginURL = loginURL + "?" + qs
 | |
| 			}
 | |
| 			c.Redirect(http.StatusFound, loginURL)
 | |
| 			c.Abort()
 | |
| 			return
 | |
| 		}
 | |
| 		c.Next()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // RequireAdmin enforces admin-only access
 | |
| func (s *Server) RequireAdmin() gin.HandlerFunc {
 | |
| 	return func(c *gin.Context) {
 | |
| 		if _, exists := c.Get("user_id"); !exists {
 | |
| 			requested := c.Request.URL.RequestURI()
 | |
| 			q := url.Values{}
 | |
| 			if requested != "" {
 | |
| 				q.Set("return_to", requested)
 | |
| 			}
 | |
| 			loginURL := s.config.URLPrefix + "/editor/login"
 | |
| 			if qs := q.Encode(); qs != "" {
 | |
| 				loginURL = loginURL + "?" + qs
 | |
| 			}
 | |
| 			c.Redirect(http.StatusFound, loginURL)
 | |
| 			c.Abort()
 | |
| 			return
 | |
| 		}
 | |
| 		if _, ok := c.Get("is_admin"); !ok {
 | |
| 			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "admin required"})
 | |
| 			return
 | |
| 		}
 | |
| 		c.Next()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // AccessLogger logs all requests to the access_logs table after handling.
 | |
| func (s *Server) AccessLogger() gin.HandlerFunc {
 | |
| 	return func(c *gin.Context) {
 | |
| 		start := time.Now()
 | |
| 		c.Next()
 | |
| 
 | |
| 		duration := time.Since(start)
 | |
| 		status := c.Writer.Status()
 | |
| 		ip := c.ClientIP()
 | |
| 		method := c.Request.Method
 | |
| 		path := c.FullPath()
 | |
| 		if path == "" {
 | |
| 			path = c.Request.URL.Path
 | |
| 		}
 | |
| 		ua := c.Request.UserAgent()
 | |
| 
 | |
| 		var userID interface{}
 | |
| 		if uidAny, ok := c.Get("user_id"); ok {
 | |
| 			if v, ok := uidAny.(int64); ok {
 | |
| 				userID = v
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		_, _ = s.auth.DB.Exec(
 | |
| 			`INSERT INTO access_logs (user_id, ip, method, path, status, duration_ms, user_agent) VALUES (?,?,?,?,?,?,?)`,
 | |
| 			userID, ip, method, path, status, duration.Milliseconds(), ua,
 | |
| 		)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // IPBanEnforce blocks banned IPs (unless whitelisted).
 | |
| func (s *Server) IPBanEnforce() gin.HandlerFunc {
 | |
| 	return func(c *gin.Context) {
 | |
| 		ip := c.ClientIP()
 | |
| 
 | |
| 		var whitelisted, permanent int
 | |
| 		var until sql.NullTime
 | |
| 		err := s.auth.DB.QueryRow(`SELECT whitelisted, permanent, until FROM ip_bans WHERE ip = ?`, ip).Scan(&whitelisted, &permanent, &until)
 | |
| 		if err != nil {
 | |
| 			if err == sql.ErrNoRows {
 | |
| 				c.Next()
 | |
| 				return
 | |
| 			}
 | |
| 			s.logError(c, "ip_ban_lookup_failed: "+err.Error(), "")
 | |
| 			c.Next()
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		if whitelisted == 1 {
 | |
| 			c.Next()
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		banned := false
 | |
| 		if permanent == 1 {
 | |
| 			banned = true
 | |
| 		} else if until.Valid && until.Time.After(time.Now()) {
 | |
| 			banned = true
 | |
| 		}
 | |
| 
 | |
| 		if banned {
 | |
| 			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "access denied"})
 | |
| 			return
 | |
| 		}
 | |
| 		c.Next()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // logError stores error logs (best-effort).
 | |
| func (s *Server) logError(c *gin.Context, message, stack string) {
 | |
| 	var userID interface{}
 | |
| 	if uidAny, ok := c.Get("user_id"); ok {
 | |
| 		if v, ok := uidAny.(int64); ok {
 | |
| 			userID = v
 | |
| 		}
 | |
| 	}
 | |
| 	ip := c.ClientIP()
 | |
| 	path := c.Request.URL.Path
 | |
| 	_, _ = s.auth.DB.Exec(`INSERT INTO error_logs (user_id, ip, path, message, stack) VALUES (?,?,?,?,?)`, userID, ip, path, message, stack)
 | |
| }
 | 
