191 lines
4.8 KiB
Go
191 lines
4.8 KiB
Go
package server
|
|
|
|
import (
	"crypto/rand"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
)
|
|
|
|
// csrfSessionKey is the key under which the per-session CSRF token is kept
// in the session's Values map.
const (
	csrfSessionKey = "csrf_token"
)
|
|
|
|
func (s *Server) randomToken(n int) (string, error) {
|
|
b := make([]byte, n)
|
|
if _, err := rand.Read(b); err != nil {
|
|
return "", err
|
|
}
|
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
}
|
|
|
|
// SessionUser loads gorilla session and exposes user_id and csrf token to context
|
|
func (s *Server) SessionUser() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
sess, _ := s.store.Get(c.Request, "gobsidian_session")
|
|
if v, ok := sess.Values["user_id"].(int64); ok {
|
|
c.Set("user_id", v)
|
|
// derive admin flag
|
|
if ok, err := s.auth.IsUserInGroup(v, "admin"); err == nil && ok {
|
|
c.Set("is_admin", true)
|
|
}
|
|
}
|
|
// ensure CSRF token exists in session
|
|
tok, _ := sess.Values[csrfSessionKey].(string)
|
|
if tok == "" {
|
|
if t, err := s.randomToken(32); err == nil {
|
|
sess.Values[csrfSessionKey] = t
|
|
_ = sess.Save(c.Request, c.Writer)
|
|
tok = t
|
|
}
|
|
}
|
|
c.Set("csrf_token", tok)
|
|
// expose CSRF token to client: header + non-HttpOnly cookie
|
|
if tok != "" {
|
|
c.Writer.Header().Set("X-CSRF-Token", tok)
|
|
// cookie accessible to JS (HttpOnly=false). Secure/ SameSite Lax for CSRF
|
|
http.SetCookie(c.Writer, &http.Cookie{
|
|
Name: "csrf_token",
|
|
Value: tok,
|
|
Path: "/",
|
|
HttpOnly: false,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// CSRFRequire validates the CSRF token for state-changing requests
|
|
func (s *Server) CSRFRequire() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead || c.Request.Method == http.MethodOptions {
|
|
c.Next()
|
|
return
|
|
}
|
|
sess, _ := s.store.Get(c.Request, "gobsidian_session")
|
|
expected, _ := sess.Values[csrfSessionKey].(string)
|
|
var token string
|
|
if h := c.GetHeader("X-CSRF-Token"); h != "" {
|
|
token = h
|
|
} else {
|
|
token = c.PostForm("csrf_token")
|
|
}
|
|
if expected == "" || token == "" || expected != token {
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "invalid CSRF token"})
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RequireAuth enforces authenticated access
|
|
func (s *Server) RequireAuth() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if _, exists := c.Get("user_id"); !exists {
|
|
c.Redirect(http.StatusFound, s.config.URLPrefix+"/editor/login")
|
|
c.Abort()
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RequireAdmin enforces admin-only access
|
|
func (s *Server) RequireAdmin() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if _, exists := c.Get("user_id"); !exists {
|
|
c.Redirect(http.StatusFound, s.config.URLPrefix+"/editor/login")
|
|
c.Abort()
|
|
return
|
|
}
|
|
if _, ok := c.Get("is_admin"); !ok {
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "admin required"})
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// AccessLogger logs all requests to the access_logs table after handling.
|
|
func (s *Server) AccessLogger() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
start := time.Now()
|
|
c.Next()
|
|
|
|
duration := time.Since(start)
|
|
status := c.Writer.Status()
|
|
ip := c.ClientIP()
|
|
method := c.Request.Method
|
|
path := c.FullPath()
|
|
if path == "" {
|
|
path = c.Request.URL.Path
|
|
}
|
|
ua := c.Request.UserAgent()
|
|
|
|
var userID interface{}
|
|
if uidAny, ok := c.Get("user_id"); ok {
|
|
if v, ok := uidAny.(int64); ok {
|
|
userID = v
|
|
}
|
|
}
|
|
|
|
_, _ = s.auth.DB.Exec(
|
|
`INSERT INTO access_logs (user_id, ip, method, path, status, duration_ms, user_agent) VALUES (?,?,?,?,?,?,?)`,
|
|
userID, ip, method, path, status, duration.Milliseconds(), ua,
|
|
)
|
|
}
|
|
}
|
|
|
|
// IPBanEnforce blocks banned IPs (unless whitelisted).
|
|
func (s *Server) IPBanEnforce() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
ip := c.ClientIP()
|
|
|
|
var whitelisted, permanent int
|
|
var until sql.NullTime
|
|
err := s.auth.DB.QueryRow(`SELECT whitelisted, permanent, until FROM ip_bans WHERE ip = ?`, ip).Scan(&whitelisted, &permanent, &until)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
c.Next()
|
|
return
|
|
}
|
|
s.logError(c, "ip_ban_lookup_failed: "+err.Error(), "")
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
if whitelisted == 1 {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
banned := false
|
|
if permanent == 1 {
|
|
banned = true
|
|
} else if until.Valid && until.Time.After(time.Now()) {
|
|
banned = true
|
|
}
|
|
|
|
if banned {
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "access denied"})
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// logError stores error logs (best-effort).
|
|
func (s *Server) logError(c *gin.Context, message, stack string) {
|
|
var userID interface{}
|
|
if uidAny, ok := c.Get("user_id"); ok {
|
|
if v, ok := uidAny.(int64); ok {
|
|
userID = v
|
|
}
|
|
}
|
|
ip := c.ClientIP()
|
|
path := c.Request.URL.Path
|
|
_, _ = s.auth.DB.Exec(`INSERT INTO error_logs (user_id, ip, path, message, stack) VALUES (?,?,?,?,?)`, userID, ip, path, message, stack)
|
|
}
|