package server import ( "crypto/rand" "database/sql" "encoding/base64" "net/http" "net/url" "time" "github.com/gin-gonic/gin" ) const csrfSessionKey = "csrf_token" func (s *Server) randomToken(n int) (string, error) { b := make([]byte, n) if _, err := rand.Read(b); err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(b), nil } // SessionUser loads gorilla session and exposes user_id and csrf token to context func (s *Server) SessionUser() gin.HandlerFunc { return func(c *gin.Context) { sess, _ := s.store.Get(c.Request, "gobsidian_session") if v, ok := sess.Values["user_id"].(int64); ok { c.Set("user_id", v) // derive admin flag if ok, err := s.auth.IsUserInGroup(v, "admin"); err == nil && ok { c.Set("is_admin", true) } } // ensure CSRF token exists in session tok, _ := sess.Values[csrfSessionKey].(string) if tok == "" { if t, err := s.randomToken(32); err == nil { sess.Values[csrfSessionKey] = t _ = sess.Save(c.Request, c.Writer) tok = t } } c.Set("csrf_token", tok) // expose CSRF token to client: header + non-HttpOnly cookie if tok != "" { c.Writer.Header().Set("X-CSRF-Token", tok) // cookie accessible to JS (HttpOnly=false). Secure/ SameSite Lax for CSRF http.SetCookie(c.Writer, &http.Cookie{ Name: "csrf_token", Value: tok, Path: "/", HttpOnly: false, SameSite: http.SameSiteLaxMode, }) } c.Next() } } // CSRFRequire validates the CSRF token for state-changing requests func (s *Server) CSRFRequire() gin.HandlerFunc { return func(c *gin.Context) { if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead || c.Request.Method == http.MethodOptions { c.Next() return } sess, _ := s.store.Get(c.Request, "gobsidian_session") expected, _ := sess.Values[csrfSessionKey].(string) var token string if h := c.GetHeader("X-CSRF-Token"); h != "" { token = h } else { token = c.PostForm("csrf_token") } if expected == "" || token == "" || expected != token { c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "invalid CSRF token"}) return } c.Next() } } // RequireAuth enforces authenticated access func (s *Server) RequireAuth() gin.HandlerFunc { return func(c *gin.Context) { if _, exists := c.Get("user_id"); !exists { // Attach return_to so user can be redirected back after login requested := c.Request.URL.RequestURI() q := url.Values{} if requested != "" { q.Set("return_to", requested) } loginURL := s.config.URLPrefix + "/editor/login" if qs := q.Encode(); qs != "" { loginURL = loginURL + "?" + qs } c.Redirect(http.StatusFound, loginURL) c.Abort() return } c.Next() } } // RequireAdmin enforces admin-only access func (s *Server) RequireAdmin() gin.HandlerFunc { return func(c *gin.Context) { if _, exists := c.Get("user_id"); !exists { requested := c.Request.URL.RequestURI() q := url.Values{} if requested != "" { q.Set("return_to", requested) } loginURL := s.config.URLPrefix + "/editor/login" if qs := q.Encode(); qs != "" { loginURL = loginURL + "?" + qs } c.Redirect(http.StatusFound, loginURL) c.Abort() return } if _, ok := c.Get("is_admin"); !ok { c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "admin required"}) return } c.Next() } } // AccessLogger logs all requests to the access_logs table after handling. func (s *Server) AccessLogger() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now() c.Next() duration := time.Since(start) status := c.Writer.Status() ip := c.ClientIP() method := c.Request.Method path := c.FullPath() if path == "" { path = c.Request.URL.Path } ua := c.Request.UserAgent() var userID interface{} if uidAny, ok := c.Get("user_id"); ok { if v, ok := uidAny.(int64); ok { userID = v } } _, _ = s.auth.DB.Exec( `INSERT INTO access_logs (user_id, ip, method, path, status, duration_ms, user_agent) VALUES (?,?,?,?,?,?,?)`, userID, ip, method, path, status, duration.Milliseconds(), ua, ) } } // IPBanEnforce blocks banned IPs (unless whitelisted). func (s *Server) IPBanEnforce() gin.HandlerFunc { return func(c *gin.Context) { ip := c.ClientIP() var whitelisted, permanent int var until sql.NullTime err := s.auth.DB.QueryRow(`SELECT whitelisted, permanent, until FROM ip_bans WHERE ip = ?`, ip).Scan(&whitelisted, &permanent, &until) if err != nil { if err == sql.ErrNoRows { c.Next() return } s.logError(c, "ip_ban_lookup_failed: "+err.Error(), "") c.Next() return } if whitelisted == 1 { c.Next() return } banned := false if permanent == 1 { banned = true } else if until.Valid && until.Time.After(time.Now()) { banned = true } if banned { c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "access denied"}) return } c.Next() } } // logError stores error logs (best-effort). func (s *Server) logError(c *gin.Context, message, stack string) { var userID interface{} if uidAny, ok := c.Get("user_id"); ok { if v, ok := uidAny.(int64); ok { userID = v } } ip := c.ClientIP() path := c.Request.URL.Path _, _ = s.auth.DB.Exec(`INSERT INTO error_logs (user_id, ip, path, message, stack) VALUES (?,?,?,?,?)`, userID, ip, path, message, stack) }