Files
gobsidian/internal/server/middleware.go
2025-08-25 21:19:15 +01:00

108 lines
2.7 KiB
Go

package server
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"net/http"

	"github.com/gin-gonic/gin"
)
// csrfSessionKey is the key under which the CSRF token is stored in the
// session's Values map.
const csrfSessionKey = "csrf_token"
// randomToken reads n bytes from crypto/rand and returns them encoded as
// unpadded, URL-safe base64. If the random source fails, the error is returned
// together with an empty string.
func (s *Server) randomToken(n int) (string, error) {
	raw := make([]byte, n)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(raw)
	return encoded, nil
}
// SessionUser loads the "gobsidian_session" session and publishes the user,
// the admin flag and the CSRF token on the gin context.
//
// Context keys set for later handlers:
//   - "user_id" (int64): only when the session holds an int64 "user_id".
//   - "is_admin" (true): only when s.auth.IsUserInGroup reports, without
//     error, that the user belongs to the "admin" group.
//   - "csrf_token" (string): always set; empty if no token could be generated.
//
// A CSRF token is generated and saved into the session when none is present.
// A non-empty token is also sent to the client as the X-CSRF-Token response
// header and as a JavaScript-readable "csrf_token" cookie.
func (s *Server) SessionUser() gin.HandlerFunc {
	return func(c *gin.Context) {
		// The lookup error is deliberately ignored so the request continues
		// as anonymous. NOTE(review): this assumes the store still returns a
		// usable session on error, as gorilla/sessions stores do — confirm.
		sess, _ := s.store.Get(c.Request, "gobsidian_session")

		if userID, ok := sess.Values["user_id"].(int64); ok {
			c.Set("user_id", userID)
			// Derive the admin flag; a lookup error leaves the user non-admin.
			if isAdmin, err := s.auth.IsUserInGroup(userID, "admin"); err == nil && isAdmin {
				c.Set("is_admin", true)
			}
		}

		// Ensure a CSRF token exists in the session.
		tok, _ := sess.Values[csrfSessionKey].(string)
		if tok == "" {
			if t, err := s.randomToken(32); err == nil {
				sess.Values[csrfSessionKey] = t
				// Best effort: if saving fails, the token below will simply
				// not validate on a later request and the client must retry.
				_ = sess.Save(c.Request, c.Writer)
				tok = t
			}
		}
		c.Set("csrf_token", tok)

		// Expose the CSRF token to the client: header + non-HttpOnly cookie.
		if tok != "" {
			c.Writer.Header().Set("X-CSRF-Token", tok)
			http.SetCookie(c.Writer, &http.Cookie{
				Name:  "csrf_token",
				Value: tok,
				Path:  "/",
				// Must stay readable by JavaScript so clients can echo it
				// back in the X-CSRF-Token header.
				HttpOnly: false,
				// Mark the cookie Secure whenever the request arrived over
				// TLS; plain-HTTP deployments keep working unchanged.
				Secure:   c.Request.TLS != nil,
				SameSite: http.SameSiteLaxMode,
			})
		}
		c.Next()
	}
}
// CSRFRequire validates the CSRF token for state-changing requests.
//
// GET, HEAD and OPTIONS requests pass through unchecked. For every other
// method the submitted token is read from the X-CSRF-Token request header,
// falling back to the "csrf_token" form field, and must equal the token stored
// in the "gobsidian_session" session. If either token is empty or they differ,
// the request is aborted with 403 Forbidden and a JSON error body.
func (s *Server) CSRFRequire() gin.HandlerFunc {
	return func(c *gin.Context) {
		switch c.Request.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions:
			c.Next()
			return
		}

		// A session lookup error is ignored here; a session without a stored
		// token is rejected by the empty-token check below.
		sess, _ := s.store.Get(c.Request, "gobsidian_session")
		expected, _ := sess.Values[csrfSessionKey].(string)

		// Prefer the header; fall back to the form field for plain HTML forms.
		token := c.GetHeader("X-CSRF-Token")
		if token == "" {
			token = c.PostForm("csrf_token")
		}

		// Compare in constant time so response timing does not reveal how
		// much of a guessed token matched the stored one.
		if expected == "" || token == "" ||
			subtle.ConstantTimeCompare([]byte(expected), []byte(token)) != 1 {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "invalid CSRF token"})
			return
		}
		c.Next()
	}
}
// RequireAuth enforces authenticated access. Requests whose context carries a
// "user_id" key continue down the handler chain; all others are redirected
// (302 Found) to /editor/login and the chain is aborted.
func (s *Server) RequireAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		_, authenticated := c.Get("user_id")
		if authenticated {
			c.Next()
			return
		}
		c.Redirect(http.StatusFound, "/editor/login")
		c.Abort()
	}
}
// RequireAdmin enforces admin-only access.
//
// Requests without a "user_id" context key are redirected (302 Found) to
// /editor/login. Requests with a user but without "is_admin" set to the
// boolean true are aborted with 403 Forbidden and a JSON error body. All
// other requests continue down the handler chain.
func (s *Server) RequireAdmin() gin.HandlerFunc {
	return func(c *gin.Context) {
		if _, exists := c.Get("user_id"); !exists {
			c.Redirect(http.StatusFound, "/editor/login")
			c.Abort()
			return
		}
		// Require the flag to be the boolean true rather than merely present,
		// so a stored false (or a non-bool value) never grants admin access.
		flag, _ := c.Get("is_admin")
		if isAdmin, ok := flag.(bool); !ok || !isAdmin {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "admin required"})
			return
		}
		c.Next()
	}
}