Files
gobsidian/internal/handlers/auth.go
nahakubuilde 090d491dd6 fix view
2025-08-26 21:43:47 +01:00

422 lines
15 KiB
Go

package handlers
import (
"crypto/hmac"
"crypto/sha1"
"crypto/subtle"
"crypto/rand"
"encoding/base32"
"encoding/binary"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"gobsidian/internal/utils"
)
// sessionCookieName is the session name passed to h.store.Get by the login,
// MFA and logout handlers in this file.
const (
	sessionCookieName = "gobsidian_session"
)
// LoginPage renders the login form. A visitor for whom isAuthenticated
// reports true is redirected to the application root (under the configured
// URL prefix) instead. An optional ?return_to= query value is echoed into the
// form so that LoginPost can send the user there after a successful login.
func (h *Handlers) LoginPage(c *gin.Context) {
	if isAuthenticated(c) {
		home := h.config.URLPrefix + "/"
		c.Redirect(http.StatusFound, home)
		return
	}

	csrf, _ := c.Get("csrf_token")
	data := gin.H{
		"Page":            "login",
		"ContentTemplate": "login_content",
		"ScriptsTemplate": "login_scripts",
		"app_name":        h.config.AppName,
		"csrf_token":      csrf,
		"return_to":       c.Query("return_to"),
	}
	c.HTML(http.StatusOK, "login", data)
}
// --- Failed login tracking and automatic IP bans ---

// Fallback values used by recordFailedAttempt whenever the corresponding
// config setting is zero or negative.
const (
	// defaultFailuresWindow is the look-back period over which failures are counted.
	defaultFailuresWindow = time.Minute * 30
	// defaultBanDuration is the length of an automatic, non-permanent IP ban.
	defaultBanDuration = time.Hour * 12

	// Number of failures within the window that triggers a ban, per attempt type.
	defaultPwdFailuresThreshold = 5
	defaultMFAFailuresThreshold = 10
)
// recordFailedAttempt logs a failed authentication attempt for the client IP
// and, once the number of failures of the same type from that IP within the
// configured window reaches the threshold, upserts an automatic ban into
// ip_bans.
//
// typ selects the threshold: "mfa" uses config.MFAFailuresThreshold, any other
// value uses config.PwdFailuresThreshold; each falls back to its default when
// <= 0. username and userID are stored with the attempt and may be ""/nil.
// Every database call is best-effort: errors are deliberately ignored so that
// failure tracking can never break the login flow itself.
func (h *Handlers) recordFailedAttempt(c *gin.Context, typ, username string, userID *int64) {
	// orDefault returns v, or def when v is not a positive config override.
	orDefault := func(v, def int) int {
		if v <= 0 {
			return def
		}
		return v
	}

	ip := c.ClientIP()
	var uid interface{} // stays nil (SQL NULL) when no user is known
	if userID != nil {
		uid = *userID
	}
	// Best-effort insert; a failure here only weakens ban detection.
	_, _ = h.authSvc.DB.Exec(`INSERT INTO failed_logins (ip, user_id, username, type) VALUES (?,?,?,?)`, ip, uid, username, typ)

	threshold := orDefault(h.config.PwdFailuresThreshold, defaultPwdFailuresThreshold)
	if typ == "mfa" {
		threshold = orDefault(h.config.MFAFailuresThreshold, defaultMFAFailuresThreshold)
	}
	windowMinutes := orDefault(h.config.FailuresWindowMinutes, int(defaultFailuresWindow/time.Minute))

	// The cutoff is computed by SQLite itself so it uses the same clock and
	// text format as created_at (presumably CURRENT_TIMESTAMP, i.e. UTC
	// "YYYY-MM-DD HH:MM:SS" — confirm against the schema). Binding a Go
	// time.Time here compared a local-time string against UTC values, which on
	// hosts east of UTC pushed the cutoff past every row and disabled bans.
	var cnt int
	_ = h.authSvc.DB.QueryRow(
		`SELECT COUNT(1) FROM failed_logins WHERE ip = ? AND type = ? AND created_at >= datetime('now', ?)`,
		ip, typ, fmt.Sprintf("-%d minutes", windowMinutes),
	).Scan(&cnt)
	if cnt < threshold {
		return
	}

	reason := fmt.Sprintf("auto-ban: %s failures=%d within %d minutes", typ, cnt, windowMinutes)
	if h.config.AutoBanPermanent {
		// Permanent ban: also upgrades an existing temporary ban; whitelisted
		// IPs are never touched.
		_, _ = h.authSvc.DB.Exec(`
INSERT INTO ip_bans (ip, reason, until, permanent, whitelisted, created_at, updated_at)
VALUES (?, ?, NULL, 1, 0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT(ip) DO UPDATE SET
reason=excluded.reason,
until=NULL,
permanent=1,
updated_at=CURRENT_TIMESTAMP
WHERE ip_bans.whitelisted = 0
`, ip, reason)
		return
	}

	banHours := orDefault(h.config.AutoBanDurationHours, int(defaultBanDuration/time.Hour))
	// NOTE(review): until is still bound as a local-time Go time.Time; confirm
	// the code that enforces ip_bans compares it in a timezone-consistent way.
	until := time.Now().Add(time.Duration(banHours) * time.Hour)
	// Temporary ban: never downgrades a permanent ban or touches whitelisted IPs.
	_, _ = h.authSvc.DB.Exec(`
INSERT INTO ip_bans (ip, reason, until, permanent, whitelisted, created_at, updated_at)
VALUES (?, ?, ?, 0, 0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT(ip) DO UPDATE SET
reason=excluded.reason,
until=excluded.until,
permanent=0,
updated_at=CURRENT_TIMESTAMP
WHERE ip_bans.whitelisted = 0 AND ip_bans.permanent = 0
`, ip, reason, until)
}
// isAllDigits reports whether s is non-empty and made up solely of the ASCII
// characters '0' through '9'. Any other character, including non-ASCII
// digits, makes it return false.
func isAllDigits(s string) bool {
	for _, r := range s {
		if r < '0' || r > '9' {
			return false
		}
	}
	return len(s) > 0
}
// MFALoginPage renders the one-time-code form for a login that LoginPost has
// parked in the session under "mfa_user_id". If no such pending login exists,
// the visitor is redirected to the login page.
func (h *Handlers) MFALoginPage(c *gin.Context) {
	session, _ := h.store.Get(c.Request, sessionCookieName)
	_, pending := session.Values["mfa_user_id"]
	if !pending {
		c.Redirect(http.StatusFound, h.config.URLPrefix+"/editor/login")
		return
	}

	csrf, _ := c.Get("csrf_token")
	c.HTML(http.StatusOK, "mfa", gin.H{
		"Page":            "mfa",
		"ContentTemplate": "mfa_content",
		"ScriptsTemplate": "mfa_scripts",
		"app_name":        h.config.AppName,
		"csrf_token":      csrf,
	})
}
// MFALoginVerify completes a login that LoginPost deferred for MFA.
//
// It checks the submitted 6-digit TOTP code against the stored users.mfa_secret
// of the pending user ("mfa_user_id" in the session). On success it replaces
// that entry with "user_id" and redirects to the session's return_to, after
// passing it through sanitizeReturnTo; otherwise it redirects to the app root.
// Every rejected code is recorded via recordFailedAttempt, so repeated guessing
// from one IP leads to an automatic ban. Without a pending MFA login the user
// is sent back to the login page.
func (h *Handlers) MFALoginVerify(c *gin.Context) {
	loginURL := h.config.URLPrefix + "/editor/login"

	// renderMFAError re-renders the OTP form with a 401 and an error message.
	// The CSRF token is always included so the user can resubmit the form.
	renderMFAError := func(msg string) {
		token, _ := c.Get("csrf_token")
		c.HTML(http.StatusUnauthorized, "mfa", gin.H{
			"app_name":        h.config.AppName,
			"csrf_token":      token,
			"error":           msg,
			"ContentTemplate": "mfa_content",
			"ScriptsTemplate": "mfa_scripts",
			"Page":            "mfa",
		})
	}

	code := strings.TrimSpace(c.PostForm("code"))
	if len(code) != 6 || !isAllDigits(code) {
		h.recordFailedAttempt(c, "mfa", "", nil)
		renderMFAError("Invalid code format")
		return
	}

	session, _ := h.store.Get(c.Request, sessionCookieName)
	uidAny, ok := session.Values["mfa_user_id"]
	if !ok {
		c.Redirect(http.StatusFound, loginURL)
		return
	}
	uid, ok := uidAny.(int64)
	if !ok {
		// A pending ID of an unexpected type cannot identify a user; drop it
		// and restart the login flow instead of querying for user 0.
		delete(session.Values, "mfa_user_id")
		_ = session.Save(c.Request, c.Writer)
		c.Redirect(http.StatusFound, loginURL)
		return
	}

	var secret string
	if err := h.authSvc.DB.QueryRow(`SELECT mfa_secret FROM users WHERE id = ?`, uid).Scan(&secret); err != nil || secret == "" {
		h.recordFailedAttempt(c, "mfa", "", &uid)
		renderMFAError("MFA not enabled")
		return
	}
	if !verifyTOTP(secret, code, time.Now()) {
		h.recordFailedAttempt(c, "mfa", "", &uid)
		renderMFAError("Invalid code")
		return
	}

	// Success: promote the pending user to a full session.
	delete(session.Values, "mfa_user_id")
	session.Values["user_id"] = uid
	var dest string
	if v, ok := session.Values["return_to"].(string); ok {
		// Re-sanitize: the session value is trusted only as far as this check.
		dest = sanitizeReturnTo(h.config.URLPrefix, v)
		delete(session.Values, "return_to")
	}
	_ = session.Save(c.Request, c.Writer)
	if dest == "" {
		dest = h.config.URLPrefix + "/"
	}
	c.Redirect(http.StatusFound, dest)
}
// ProfileMFASetupPage shows the MFA enrollment page (secret plus otpauth:// URI
// for a QR code) to the logged-in user. It redirects to the login page when
// there is no user.
//
// A pending row in mfa_enrollments is reused, so reloading the page keeps the
// same secret. Otherwise a new secret is generated and stored there, which is
// where ProfileMFASetupVerify reads it back from. If that secret cannot be
// generated or stored, an error page (500) is rendered instead.
func (h *Handlers) ProfileMFASetupPage(c *gin.Context) {
	uidPtr := getUserIDPtr(c)
	if uidPtr == nil {
		c.Redirect(http.StatusFound, h.config.URLPrefix+"/editor/login")
		return
	}

	var secret string
	err := h.authSvc.DB.QueryRow(`SELECT secret FROM mfa_enrollments WHERE user_id = ?`, *uidPtr).Scan(&secret)
	if err != nil || secret == "" {
		s, e := generateBase32Secret()
		if e == nil {
			// The secret must be persisted: showing an unsaved secret would make
			// the enrollment impossible to verify.
			_, e = h.authSvc.DB.Exec(`INSERT OR REPLACE INTO mfa_enrollments (user_id, secret) VALUES (?, ?)`, *uidPtr, s)
		}
		if e != nil {
			c.HTML(http.StatusInternalServerError, "error", gin.H{"error": "Failed to create enrollment", "Page": "error", "ContentTemplate": "error_content", "ScriptsTemplate": "error_scripts", "app_name": h.config.AppName})
			return
		}
		secret = s
	}

	// The username only labels the entry in the authenticator app; a lookup
	// failure leaves it empty rather than failing the page.
	var username string
	_ = h.authSvc.DB.QueryRow(`SELECT username FROM users WHERE id = ?`, *uidPtr).Scan(&username)

	issuer := h.config.AppName
	label := url.PathEscape(fmt.Sprintf("%s:%s", issuer, username))
	// The Key URI format expects RFC 3986 percent-encoding. QueryEscape
	// form-encodes spaces as "+", which some authenticator apps display
	// literally, so switch those to %20 (a literal "+" is already %2B).
	issuerParam := strings.ReplaceAll(url.QueryEscape(issuer), "+", "%20")
	otpauth := fmt.Sprintf("otpauth://totp/%s?secret=%s&issuer=%s&digits=6&period=30&algorithm=SHA1", label, secret, issuerParam)

	// Build the sidebar tree for a consistent UI and pass auth flags.
	notesTree, _ := utils.BuildTreeStructure(h.config.NotesDir, h.config.NotesDirHideSidepane, h.config)
	c.HTML(http.StatusOK, "mfa_setup", gin.H{
		"app_name":        h.config.AppName,
		"notes_tree":      notesTree,
		"active_path":     []string{},
		"current_note":    nil,
		"breadcrumbs":     utils.GenerateBreadcrumbs(""),
		"Authenticated":   true,
		"IsAdmin":         isAdmin(c),
		"Secret":          secret,
		"OTPAuthURI":      otpauth,
		"ContentTemplate": "mfa_setup_content",
		"ScriptsTemplate": "mfa_setup_scripts",
		"Page":            "mfa_setup",
	})
}
// ProfileMFASetupVerify completes MFA enrollment for the logged-in user.
//
// It checks the submitted 6-digit code against the pending secret in
// mfa_enrollments. On success it copies that secret to users.mfa_secret, the
// value LoginPost uses to decide whether MFA is required, and removes the
// enrollment row. Responses are JSON: {"success": true} or {"error": "..."}
// with 401, 400 or 500 as appropriate.
func (h *Handlers) ProfileMFASetupVerify(c *gin.Context) {
	uidPtr := getUserIDPtr(c)
	if uidPtr == nil {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "not authenticated"})
		return
	}
	code := strings.TrimSpace(c.PostForm("code"))
	if len(code) != 6 || !isAllDigits(code) {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid code format"})
		return
	}
	var secret string
	if err := h.authSvc.DB.QueryRow(`SELECT secret FROM mfa_enrollments WHERE user_id = ?`, *uidPtr).Scan(&secret); err != nil || secret == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "no enrollment found"})
		return
	}
	if !verifyTOTP(secret, code, time.Now()) {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid code"})
		return
	}
	if _, err := h.authSvc.DB.Exec(`UPDATE users SET mfa_secret = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, secret, *uidPtr); err != nil {
		// Do not echo the raw database error to the client: it can reveal
		// schema or storage details.
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to enable MFA"})
		return
	}
	// Best-effort cleanup: a leftover enrollment row does not force MFA, since
	// LoginPost only looks at users.mfa_secret.
	_, _ = h.authSvc.DB.Exec(`DELETE FROM mfa_enrollments WHERE user_id = ?`, *uidPtr)
	c.JSON(http.StatusOK, gin.H{"success": true})
}
// TOTP helpers (SHA1, 30s window, 6 digits)

// generateBase32Secret returns a fresh TOTP shared secret. It reads 20 bytes
// from crypto/rand and encodes them as unpadded standard base32, which yields
// 32 characters. It returns any error from the random source.
func generateBase32Secret() (string, error) {
	const secretLen = 20
	raw := make([]byte, secretLen)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(raw), nil
}
// hotp computes a 6-digit HOTP value (RFC 4226) for the given key and counter.
// It uses HMAC-SHA1 with dynamic truncation and zero-pads the result to six
// characters.
func hotp(secret []byte, counter uint64) string {
	var msg [8]byte
	binary.BigEndian.PutUint64(msg[:], counter)
	mac := hmac.New(sha1.New, secret)
	mac.Write(msg[:]) // hash.Hash.Write never returns an error
	sum := mac.Sum(nil)
	// Dynamic truncation: the low nibble of the last byte selects a 4-byte
	// window (offset <= 15, so it always fits in the 20-byte digest), and the
	// top bit is masked off to obtain a 31-bit integer.
	offset := int(sum[len(sum)-1] & 0x0F)
	code := binary.BigEndian.Uint32(sum[offset:offset+4]) & 0x7FFFFFFF
	return fmt.Sprintf("%06d", code%1000000)
}

// verifyTOTP reports whether code is a valid 6-digit TOTP (RFC 6238,
// HMAC-SHA1, 30-second step) for base32Secret at time t. Codes from the
// immediately preceding and following steps are also accepted to tolerate
// clock drift.
//
// The secret is decoded case-insensitively. Whitespace and trailing "="
// padding are ignored, so secrets copied in grouped or padded form still work.
// An empty or undecodable secret, or a time before the Unix epoch, never
// verifies.
func verifyTOTP(base32Secret, code string, t time.Time) bool {
	normalized := strings.ToUpper(strings.Join(strings.Fields(base32Secret), ""))
	normalized = strings.TrimRight(normalized, "=")
	key, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(normalized)
	if err != nil || len(key) == 0 {
		// With an empty HMAC key anyone could compute valid codes.
		return false
	}
	unix := t.Unix()
	if unix < 0 {
		return false
	}
	step := uint64(unix / 30)
	first := step
	if first > 0 {
		first-- // avoid wrapping to MaxUint64 in the very first step
	}
	// Check every candidate without an early exit so timing does not reveal
	// which window matched.
	matched := 0
	for counter := first; counter <= step+1; counter++ {
		matched |= subtle.ConstantTimeCompare([]byte(hotp(key, counter)), []byte(code))
	}
	return matched == 1
}
// LoginPost handles the login form submission.
//
// Credentials are checked with h.authSvc.Authenticate. On failure the attempt
// is recorded via recordFailedAttempt and the form is re-rendered with the
// error. A user with an MFA secret is not logged in yet: they are parked in
// the session as "mfa_user_id" and sent to /editor/mfa. Everyone else gets
// "user_id" set immediately and is redirected to the sanitized return_to, or
// to the app root. In both cases, session keys left over from an earlier or
// abandoned login are cleared so they cannot leak into the new one.
func (h *Handlers) LoginPost(c *gin.Context) {
	username := c.PostForm("username")
	password := c.PostForm("password")
	returnTo := strings.TrimSpace(c.PostForm("return_to"))
	user, err := h.authSvc.Authenticate(username, password)
	if err != nil {
		token, _ := c.Get("csrf_token")
		h.recordFailedAttempt(c, "password", username, nil)
		c.HTML(http.StatusUnauthorized, "login", gin.H{
			"app_name":        h.config.AppName,
			"csrf_token":      token,
			"error":           err.Error(),
			"return_to":       returnTo,
			"ContentTemplate": "login_content",
			"ScriptsTemplate": "login_scripts",
			"Page":            "login",
		})
		return
	}

	safeReturnTo := sanitizeReturnTo(h.config.URLPrefix, returnTo)
	session, _ := h.store.Get(c.Request, sessionCookieName)

	// If the user has MFA enabled, require the OTP before setting user_id.
	if user.MFASecret.Valid && user.MFASecret.String != "" {
		// The password alone must not grant access, so drop any identity left
		// in the session from an earlier login.
		delete(session.Values, "user_id")
		session.Values["mfa_user_id"] = user.ID
		if safeReturnTo != "" {
			session.Values["return_to"] = safeReturnTo
		} else {
			// Do not let a return_to from an abandoned attempt win.
			delete(session.Values, "return_to")
		}
		_ = session.Save(c.Request, c.Writer)
		c.Redirect(http.StatusFound, h.config.URLPrefix+"/editor/mfa")
		return
	}

	// Do NOT automatically force MFA setup just because an enrollment row exists.
	// Some deployments may leave stale enrollment rows; MFA is only required when
	// the user actually has MFA enabled (mfa_secret set) or when they explicitly
	// navigate to setup from their profile.

	// Normal session: clear leftovers from an abandoned MFA login so a pending
	// user ID or redirect cannot be completed later on this session.
	delete(session.Values, "mfa_user_id")
	delete(session.Values, "return_to")
	session.Values["user_id"] = user.ID
	_ = session.Save(c.Request, c.Writer)

	dest := safeReturnTo
	if dest == "" {
		dest = h.config.URLPrefix + "/"
	}
	c.Redirect(http.StatusFound, dest)
}
// LogoutPost ends the current session and redirects to the login page under
// the configured URL prefix. The session is saved with MaxAge = -1. For
// gorilla/sessions-style stores this presumably expires the session cookie;
// confirm against the h.store implementation. A save error is ignored.
func (h *Handlers) LogoutPost(c *gin.Context) {
	loginURL := h.config.URLPrefix + "/editor/login"
	sess, _ := h.store.Get(c.Request, sessionCookieName)
	sess.Options.MaxAge = -1
	_ = sess.Save(c.Request, c.Writer) // best-effort; redirect regardless
	c.Redirect(http.StatusFound, loginURL)
}
// sanitizeReturnTo turns a user-supplied return_to value into a safe in-app
// redirect target, or returns "" when the value must be ignored.
//
// It rejects:
//   - values url.Parse cannot parse;
//   - absolute URLs (those with a scheme);
//   - protocol-relative URLs ("//host/...");
//   - any value containing a backslash. Browsers treat "\" like "/", so
//     "/\evil.example" would otherwise redirect off-site.
//
// Relative paths are made root-relative. When prefix (the app's URL prefix,
// presumably without a trailing slash, since callers append "/" to it) is
// non-empty, the result stays inside it: a path already under the prefix is
// returned unchanged, and anything else is rewritten to prefix+path.
func sanitizeReturnTo(prefix, v string) string {
	v = strings.TrimSpace(v)
	if v == "" {
		return ""
	}
	if strings.HasPrefix(v, "//") || strings.Contains(v, `\`) {
		return ""
	}
	if u, err := url.Parse(v); err != nil || u.IsAbs() {
		return ""
	}
	if !strings.HasPrefix(v, "/") {
		v = "/" + v
	}
	if prefix == "" {
		return v
	}
	if v == prefix || strings.HasPrefix(v, prefix+"/") {
		return v
	}
	// Root-relative path outside the prefix: rewrite it into the prefix.
	return prefix + v
}