This commit is contained in:
nahakubuilde
2025-08-26 21:43:47 +01:00
parent e8658f5aab
commit 090d491dd6
9 changed files with 168 additions and 59 deletions

View File

@@ -14,6 +14,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"gobsidian/internal/utils"
)
const sessionCookieName = "gobsidian_session"
@@ -26,9 +27,12 @@ func (h *Handlers) LoginPage(c *gin.Context) {
return
}
token, _ := c.Get("csrf_token")
// propagate return_to if provided
returnTo := c.Query("return_to")
c.HTML(http.StatusOK, "login", gin.H{
"app_name": h.config.AppName,
"csrf_token": token,
"return_to": returnTo,
"ContentTemplate": "login_content",
"ScriptsTemplate": "login_scripts",
"Page": "login",
@@ -192,8 +196,17 @@ func (h *Handlers) MFALoginVerify(c *gin.Context) {
// success: set user_id and clear mfa_user_id
delete(session.Values, "mfa_user_id")
session.Values["user_id"] = uid
// use return_to if set in session
var dest string
if v, ok := session.Values["return_to"].(string); ok {
dest = sanitizeReturnTo(h.config.URLPrefix, v)
delete(session.Values, "return_to")
}
_ = session.Save(c.Request, c.Writer)
c.Redirect(http.StatusFound, h.config.URLPrefix+"/")
if dest == "" {
dest = h.config.URLPrefix + "/"
}
c.Redirect(http.StatusFound, dest)
}
// ProfileMFASetupPage shows QR and input to verify during enrollment
@@ -223,9 +236,16 @@ func (h *Handlers) ProfileMFASetupPage(c *gin.Context) {
label := url.PathEscape(fmt.Sprintf("%s:%s", issuer, username))
otpauth := fmt.Sprintf("otpauth://totp/%s?secret=%s&issuer=%s&digits=6&period=30&algorithm=SHA1", label, secret, url.QueryEscape(issuer))
// Render simple page (uses base.html shell)
// Build sidebar tree for consistent UI and pass auth flags
notesTree, _ := utils.BuildTreeStructure(h.config.NotesDir, h.config.NotesDirHideSidepane, h.config)
c.HTML(http.StatusOK, "mfa_setup", gin.H{
"app_name": h.config.AppName,
"notes_tree": notesTree,
"active_path": []string{},
"current_note": nil,
"breadcrumbs": utils.GenerateBreadcrumbs(""),
"Authenticated": true,
"IsAdmin": isAdmin(c),
"Secret": secret,
"OTPAuthURI": otpauth,
"ContentTemplate": "mfa_setup_content",
@@ -313,6 +333,7 @@ func verifyTOTP(base32Secret, code string, t time.Time) bool {
func (h *Handlers) LoginPost(c *gin.Context) {
username := c.PostForm("username")
password := c.PostForm("password")
returnTo := strings.TrimSpace(c.PostForm("return_to"))
user, err := h.authSvc.Authenticate(username, password)
if err != nil {
@@ -323,6 +344,7 @@ func (h *Handlers) LoginPost(c *gin.Context) {
"app_name": h.config.AppName,
"csrf_token": token,
"error": err.Error(),
"return_to": returnTo,
"ContentTemplate": "login_content",
"ScriptsTemplate": "login_scripts",
"Page": "login",
@@ -333,28 +355,30 @@ func (h *Handlers) LoginPost(c *gin.Context) {
if user.MFASecret.Valid && user.MFASecret.String != "" {
session, _ := h.store.Get(c.Request, sessionCookieName)
session.Values["mfa_user_id"] = user.ID
if rt := sanitizeReturnTo(h.config.URLPrefix, returnTo); rt != "" {
session.Values["return_to"] = rt
}
_ = session.Save(c.Request, c.Writer)
c.Redirect(http.StatusFound, h.config.URLPrefix+"/editor/mfa")
return
}
// If admin created an enrollment for this user, force MFA setup after login
var pending int
if err := h.authSvc.DB.QueryRow(`SELECT 1 FROM mfa_enrollments WHERE user_id = ?`, user.ID).Scan(&pending); err == nil {
// normal login, then redirect to setup
session, _ := h.store.Get(c.Request, sessionCookieName)
session.Values["user_id"] = user.ID
_ = session.Save(c.Request, c.Writer)
c.Redirect(http.StatusFound, h.config.URLPrefix+"/editor/profile/mfa/setup")
return
}
// Do NOT automatically force MFA setup just because an enrollment row exists.
// Some deployments may leave stale enrollment rows; we only require MFA when
// the user actually has MFA enabled (mfa_secret set) or when they explicitly
// navigate to setup from profile.
// Create normal session
session, _ := h.store.Get(c.Request, sessionCookieName)
session.Values["user_id"] = user.ID
_ = session.Save(c.Request, c.Writer)
c.Redirect(http.StatusFound, h.config.URLPrefix+"/")
// Redirect to requested page if provided and safe; otherwise home
if rt := sanitizeReturnTo(h.config.URLPrefix, returnTo); rt != "" {
c.Redirect(http.StatusFound, rt)
} else {
c.Redirect(http.StatusFound, h.config.URLPrefix+"/")
}
}
// LogoutPost clears the session
@@ -364,3 +388,34 @@ func (h *Handlers) LogoutPost(c *gin.Context) {
_ = session.Save(c.Request, c.Writer)
c.Redirect(http.StatusFound, h.config.URLPrefix+"/editor/login")
}
// sanitizeReturnTo ensures the provided return_to is a safe in-app path.
// It rejects absolute URLs and protocol-relative URLs. When URLPrefix is set,
// it enforces that the destination stays within that prefix; if a bare
// "/..." path is provided, it will be rewritten to include the prefix.
func sanitizeReturnTo(prefix, v string) string {
v = strings.TrimSpace(v)
if v == "" {
return ""
}
// Disallow absolute and protocol-relative URLs
if strings.HasPrefix(v, "//") {
return ""
}
if u, err := url.Parse(v); err != nil || (u != nil && u.IsAbs()) {
return ""
}
// Must be a path
if !strings.HasPrefix(v, "/") {
v = "/" + v
}
// Enforce prefix containment when configured
if prefix != "" {
if strings.HasPrefix(v, prefix+"/") || v == prefix || v == prefix+"/" {
return v
}
// If it's a root-relative path without prefix, rewrite into prefix
return prefix + v
}
return v
}

View File

@@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/base64"
"net/http"
"net/url"
"time"
"github.com/gin-gonic/gin"
@@ -84,7 +85,17 @@ func (s *Server) CSRFRequire() gin.HandlerFunc {
func (s *Server) RequireAuth() gin.HandlerFunc {
return func(c *gin.Context) {
if _, exists := c.Get("user_id"); !exists {
c.Redirect(http.StatusFound, s.config.URLPrefix+"/editor/login")
// Attach return_to so user can be redirected back after login
requested := c.Request.URL.RequestURI()
q := url.Values{}
if requested != "" {
q.Set("return_to", requested)
}
loginURL := s.config.URLPrefix + "/editor/login"
if qs := q.Encode(); qs != "" {
loginURL = loginURL + "?" + qs
}
c.Redirect(http.StatusFound, loginURL)
c.Abort()
return
}
@@ -96,7 +107,16 @@ func (s *Server) RequireAuth() gin.HandlerFunc {
func (s *Server) RequireAdmin() gin.HandlerFunc {
return func(c *gin.Context) {
if _, exists := c.Get("user_id"); !exists {
c.Redirect(http.StatusFound, s.config.URLPrefix+"/editor/login")
requested := c.Request.URL.RequestURI()
q := url.Values{}
if requested != "" {
q.Set("return_to", requested)
}
loginURL := s.config.URLPrefix + "/editor/login"
if qs := q.Encode(); qs != "" {
loginURL = loginURL + "?" + qs
}
c.Redirect(http.StatusFound, loginURL)
c.Abort()
return
}