362 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			362 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package handlers
 | |
| 
 | |
| import (
 | |
|     "crypto/hmac"
 | |
|     "crypto/sha1"
 | |
|     "crypto/subtle"
 | |
|     "crypto/rand"
 | |
|     "encoding/base32"
 | |
|     "encoding/binary"
 | |
|     "fmt"
 | |
|     "net/http"
 | |
|     "net/url"
 | |
|     "strings"
 | |
|     "time"
 | |
| 
 | |
|     "github.com/gin-gonic/gin"
 | |
| )
 | |
| 
 | |
// sessionCookieName is the session name passed to h.store.Get by the login,
// MFA and logout handlers.
const (
	sessionCookieName = "gobsidian_session"
)
 | |
| 
 | |
| // LoginPage renders the login form
 | |
| func (h *Handlers) LoginPage(c *gin.Context) {
 | |
|     token, _ := c.Get("csrf_token")
 | |
|     c.HTML(http.StatusOK, "login", gin.H{
 | |
|         "app_name":        h.config.AppName,
 | |
|         "csrf_token":      token,
 | |
|         "ContentTemplate": "login_content",
 | |
|         "ScriptsTemplate": "login_scripts",
 | |
|         "Page":            "login",
 | |
|     })
 | |
| }
 | |
| 
 | |
// --- Failed login tracking and automatic IP bans ---

// The following values are the fallbacks recordFailedAttempt uses when the
// corresponding config setting is zero or negative.

// defaultPwdFailuresThreshold is how many failed password attempts from one IP
// within the failure window trigger an automatic ban.
const defaultPwdFailuresThreshold = 5

// defaultMFAFailuresThreshold is the same limit for failed MFA codes.
const defaultMFAFailuresThreshold = 10

// defaultFailuresWindow is how far back failed attempts are counted.
const defaultFailuresWindow = 30 * time.Minute

// defaultBanDuration is how long a non-permanent automatic ban lasts.
const defaultBanDuration = 12 * time.Hour
 | |
| 
 | |
| // recordFailedAttempt logs a failed attempt and applies IP ban if thresholds exceeded.
 | |
| func (h *Handlers) recordFailedAttempt(c *gin.Context, typ, username string, userID *int64) {
 | |
|     ip := c.ClientIP()
 | |
|     var uid interface{}
 | |
|     if userID != nil {
 | |
|         uid = *userID
 | |
|     }
 | |
|     // Insert failed attempt (best-effort)
 | |
|     _, _ = h.authSvc.DB.Exec(`INSERT INTO failed_logins (ip, user_id, username, type) VALUES (?,?,?,?)`, ip, uid, username, typ)
 | |
| 
 | |
|     // Determine threshold using config overrides (fallback to defaults when zero)
 | |
|     threshold := h.config.PwdFailuresThreshold
 | |
|     if threshold <= 0 {
 | |
|         threshold = defaultPwdFailuresThreshold
 | |
|     }
 | |
|     if typ == "mfa" {
 | |
|         t2 := h.config.MFAFailuresThreshold
 | |
|         if t2 <= 0 {
 | |
|             t2 = defaultMFAFailuresThreshold
 | |
|         }
 | |
|         threshold = t2
 | |
|     }
 | |
|     // Count recent failures for this IP and type
 | |
|     var cnt int
 | |
|     // Window from config (minutes)
 | |
|     windowMinutes := h.config.FailuresWindowMinutes
 | |
|     if windowMinutes <= 0 {
 | |
|         windowMinutes = int(defaultFailuresWindow / time.Minute)
 | |
|     }
 | |
|     cutoff := time.Now().Add(-time.Duration(windowMinutes) * time.Minute)
 | |
|     _ = h.authSvc.DB.QueryRow(`SELECT COUNT(1) FROM failed_logins WHERE ip = ? AND type = ? AND created_at >= ?`, ip, typ, cutoff).Scan(&cnt)
 | |
|     if cnt >= threshold {
 | |
|         // Duration/permanence from config
 | |
|         banHours := h.config.AutoBanDurationHours
 | |
|         if banHours <= 0 {
 | |
|             banHours = int(defaultBanDuration / time.Hour)
 | |
|         }
 | |
|         makePermanent := h.config.AutoBanPermanent
 | |
|         var until interface{}
 | |
|         if makePermanent {
 | |
|             until = nil
 | |
|         } else {
 | |
|             until = time.Now().Add(time.Duration(banHours) * time.Hour)
 | |
|         }
 | |
|         reason := fmt.Sprintf("auto-ban: %s failures=%d within %d minutes", typ, cnt, windowMinutes)
 | |
|         // Upsert ban unless whitelisted or permanent existing
 | |
|         if makePermanent {
 | |
|             _, _ = h.authSvc.DB.Exec(`
 | |
|                 INSERT INTO ip_bans (ip, reason, until, permanent, whitelisted, created_at, updated_at)
 | |
|                 VALUES (?, ?, NULL, 1, 0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
 | |
|                 ON CONFLICT(ip) DO UPDATE SET
 | |
|                     reason=excluded.reason,
 | |
|                     until=NULL,
 | |
|                     permanent=1,
 | |
|                     updated_at=CURRENT_TIMESTAMP
 | |
|                 WHERE ip_bans.whitelisted = 0
 | |
|             `, ip, reason)
 | |
|         } else {
 | |
|             _, _ = h.authSvc.DB.Exec(`
 | |
|                 INSERT INTO ip_bans (ip, reason, until, permanent, whitelisted, created_at, updated_at)
 | |
|                 VALUES (?, ?, ?, 0, 0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
 | |
|                 ON CONFLICT(ip) DO UPDATE SET
 | |
|                     reason=excluded.reason,
 | |
|                     until=excluded.until,
 | |
|                     permanent=0,
 | |
|                     updated_at=CURRENT_TIMESTAMP
 | |
|                 WHERE ip_bans.whitelisted = 0 AND ip_bans.permanent = 0
 | |
|             `, ip, reason, until)
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
// isAllDigits reports whether s is non-empty and every byte of s is an ASCII
// decimal digit ('0' through '9'). Non-ASCII digits (e.g. Arabic-Indic) are
// rejected because the check is byte-wise.
func isAllDigits(s string) bool {
	if len(s) == 0 {
		return false
	}
	for _, b := range []byte(s) {
		if !('0' <= b && b <= '9') {
			return false
		}
	}
	return true
}
 | |
| 
 | |
| // MFALoginPage shows OTP prompt when MFA is enabled
 | |
| func (h *Handlers) MFALoginPage(c *gin.Context) {
 | |
|     session, _ := h.store.Get(c.Request, sessionCookieName)
 | |
|     if _, ok := session.Values["mfa_user_id"]; !ok {
 | |
|         c.Redirect(http.StatusFound, "/editor/login")
 | |
|         return
 | |
|     }
 | |
|     token, _ := c.Get("csrf_token")
 | |
|     c.HTML(http.StatusOK, "mfa", gin.H{
 | |
|         "app_name":        h.config.AppName,
 | |
|         "csrf_token":      token,
 | |
|         "ContentTemplate": "mfa_content",
 | |
|         "ScriptsTemplate": "mfa_scripts",
 | |
|         "Page":            "mfa",
 | |
|     })
 | |
| }
 | |
| 
 | |
| // MFALoginVerify verifies OTP and completes login
 | |
| func (h *Handlers) MFALoginVerify(c *gin.Context) {
 | |
|     code := strings.TrimSpace(c.PostForm("code"))
 | |
|     if len(code) != 6 || !isAllDigits(code) {
 | |
|         token, _ := c.Get("csrf_token")
 | |
|         // record failed MFA attempt
 | |
|         h.recordFailedAttempt(c, "mfa", "", nil)
 | |
|         c.HTML(http.StatusUnauthorized, "mfa", gin.H{
 | |
|             "app_name":        h.config.AppName,
 | |
|             "csrf_token":      token,
 | |
|             "error":           "Invalid code format",
 | |
|             "ContentTemplate": "mfa_content",
 | |
|             "ScriptsTemplate": "mfa_scripts",
 | |
|             "Page":            "mfa",
 | |
|         })
 | |
|         return
 | |
|     }
 | |
|     session, _ := h.store.Get(c.Request, sessionCookieName)
 | |
|     uidAny, ok := session.Values["mfa_user_id"]
 | |
|     if !ok {
 | |
|         c.Redirect(http.StatusFound, "/editor/login")
 | |
|         return
 | |
|     }
 | |
|     uid, _ := uidAny.(int64)
 | |
|     var secret string
 | |
|     if err := h.authSvc.DB.QueryRow(`SELECT mfa_secret FROM users WHERE id = ?`, uid).Scan(&secret); err != nil || secret == "" {
 | |
|         h.recordFailedAttempt(c, "mfa", "", &uid)
 | |
|         c.HTML(http.StatusUnauthorized, "mfa", gin.H{"error": "MFA not enabled", "Page": "mfa", "ContentTemplate": "mfa_content", "ScriptsTemplate": "mfa_scripts", "app_name": h.config.AppName})
 | |
|         return
 | |
|     }
 | |
|     if !verifyTOTP(secret, code, time.Now()) {
 | |
|         token, _ := c.Get("csrf_token")
 | |
|         h.recordFailedAttempt(c, "mfa", "", &uid)
 | |
|         c.HTML(http.StatusUnauthorized, "mfa", gin.H{
 | |
|             "app_name":        h.config.AppName,
 | |
|             "csrf_token":      token,
 | |
|             "error":           "Invalid code",
 | |
|             "ContentTemplate": "mfa_content",
 | |
|             "ScriptsTemplate": "mfa_scripts",
 | |
|             "Page":            "mfa",
 | |
|         })
 | |
|         return
 | |
|     }
 | |
|     // success: set user_id and clear mfa_user_id
 | |
|     delete(session.Values, "mfa_user_id")
 | |
|     session.Values["user_id"] = uid
 | |
|     _ = session.Save(c.Request, c.Writer)
 | |
|     c.Redirect(http.StatusFound, "/")
 | |
| }
 | |
| 
 | |
| // ProfileMFASetupPage shows QR and input to verify during enrollment
 | |
| func (h *Handlers) ProfileMFASetupPage(c *gin.Context) {
 | |
|     uidPtr := getUserIDPtr(c)
 | |
|     if uidPtr == nil {
 | |
|         c.Redirect(http.StatusFound, "/editor/login")
 | |
|         return
 | |
|     }
 | |
|     // ensure enrollment exists, otherwise create one
 | |
|     var secret string
 | |
|     err := h.authSvc.DB.QueryRow(`SELECT secret FROM mfa_enrollments WHERE user_id = ?`, *uidPtr).Scan(&secret)
 | |
|     if err != nil || secret == "" {
 | |
|         // create new enrollment
 | |
|         s, e := generateBase32Secret()
 | |
|         if e != nil {
 | |
|             c.HTML(http.StatusInternalServerError, "error", gin.H{"error": "Failed to create enrollment", "Page": "error", "ContentTemplate": "error_content", "ScriptsTemplate": "error_scripts", "app_name": h.config.AppName})
 | |
|             return
 | |
|         }
 | |
|         _, _ = h.authSvc.DB.Exec(`INSERT OR REPLACE INTO mfa_enrollments (user_id, secret) VALUES (?, ?)`, *uidPtr, s)
 | |
|         secret = s
 | |
|     }
 | |
|     // Fetch username for label
 | |
|     var username string
 | |
|     _ = h.authSvc.DB.QueryRow(`SELECT username FROM users WHERE id = ?`, *uidPtr).Scan(&username)
 | |
|     issuer := h.config.AppName
 | |
|     label := url.PathEscape(fmt.Sprintf("%s:%s", issuer, username))
 | |
|     otpauth := fmt.Sprintf("otpauth://totp/%s?secret=%s&issuer=%s&digits=6&period=30&algorithm=SHA1", label, secret, url.QueryEscape(issuer))
 | |
| 
 | |
|     // Render simple page (uses base.html shell)
 | |
|     c.HTML(http.StatusOK, "mfa_setup", gin.H{
 | |
|         "app_name":        h.config.AppName,
 | |
|         "Secret":          secret,
 | |
|         "OTPAuthURI":      otpauth,
 | |
|         "ContentTemplate": "mfa_setup_content",
 | |
|         "ScriptsTemplate": "mfa_setup_scripts",
 | |
|         "Page":            "mfa_setup",
 | |
|     })
 | |
| }
 | |
| 
 | |
| // ProfileMFASetupVerify finalizes enrollment by verifying a TOTP code and setting mfa_secret
 | |
| func (h *Handlers) ProfileMFASetupVerify(c *gin.Context) {
 | |
|     uidPtr := getUserIDPtr(c)
 | |
|     if uidPtr == nil {
 | |
|         c.JSON(http.StatusUnauthorized, gin.H{"error": "not authenticated"})
 | |
|         return
 | |
|     }
 | |
|     code := strings.TrimSpace(c.PostForm("code"))
 | |
|     if len(code) != 6 || !isAllDigits(code) {
 | |
|         c.JSON(http.StatusBadRequest, gin.H{"error": "invalid code format"})
 | |
|         return
 | |
|     }
 | |
|     var secret string
 | |
|     if err := h.authSvc.DB.QueryRow(`SELECT secret FROM mfa_enrollments WHERE user_id = ?`, *uidPtr).Scan(&secret); err != nil || secret == "" {
 | |
|         c.JSON(http.StatusBadRequest, gin.H{"error": "no enrollment found"})
 | |
|         return
 | |
|     }
 | |
|     if !verifyTOTP(secret, code, time.Now()) {
 | |
|         c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid code"})
 | |
|         return
 | |
|     }
 | |
|     // move secret to users and delete enrollment
 | |
|     if _, err := h.authSvc.DB.Exec(`UPDATE users SET mfa_secret = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, secret, *uidPtr); err != nil {
 | |
|         c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 | |
|         return
 | |
|     }
 | |
|     _, _ = h.authSvc.DB.Exec(`DELETE FROM mfa_enrollments WHERE user_id = ?`, *uidPtr)
 | |
|     c.JSON(http.StatusOK, gin.H{"success": true})
 | |
| }
 | |
| 
 | |
// TOTP helpers (SHA1, 30s window, 6 digits)

// generateBase32Secret returns a new TOTP secret. It reads 20 bytes (160 bits)
// from crypto/rand and encodes them in standard base32 without '=' padding,
// giving a 32-character string. The error is non-nil only when the random
// source fails.
func generateBase32Secret() (string, error) {
	var raw [20]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(raw[:]), nil
}
 | |
| 
 | |
// hotp computes an RFC 4226 HOTP value.
//
// It takes HMAC-SHA1 over the big-endian 8-byte counter and applies dynamic
// truncation to a 31-bit integer. The result is reduced to 6 decimal digits,
// zero-padded.
func hotp(secret []byte, counter uint64) string {
	var msg [8]byte
	binary.BigEndian.PutUint64(msg[:], counter)
	mac := hmac.New(sha1.New, secret)
	mac.Write(msg[:]) // hash.Hash.Write never returns an error
	sum := mac.Sum(nil)

	// Dynamic truncation: the low nibble of the last byte selects a 4-byte
	// window (offset <= 15, so offset+4 <= len(sum) = 20). The top bit is
	// masked off so the value is non-negative.
	offset := int(sum[len(sum)-1] & 0x0F)
	value := binary.BigEndian.Uint32(sum[offset:offset+4]) & 0x7FFFFFFF
	return fmt.Sprintf("%06d", value%1000000)
}

// verifyTOTP reports whether code is a valid 6-digit TOTP (SHA-1, 30-second
// steps, Unix-epoch origin) for base32Secret at time t. Codes for the previous
// and next step are also accepted, to tolerate clock drift.
//
// The secret is normalized before decoding:
//   - ASCII spaces are removed.
//   - Letters are upper-cased.
//   - Trailing '=' padding is dropped.
//
// This way the unpadded form from generateBase32Secret, padded base32, and
// space-grouped secrets all decode. An undecodable or empty secret never
// verifies.
func verifyTOTP(base32Secret, code string, t time.Time) bool {
	normalized := strings.ToUpper(strings.ReplaceAll(base32Secret, " ", ""))
	normalized = strings.TrimRight(normalized, "=")
	key, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(normalized)
	if err != nil || len(key) == 0 {
		return false
	}

	step := uint64(t.Unix() / 30)
	// Check every candidate without short-circuiting, so the running time does
	// not reveal which window (if any) matched.
	matched := 0
	for _, s := range []uint64{step - 1, step, step + 1} {
		matched |= subtle.ConstantTimeCompare([]byte(hotp(key, s)), []byte(code))
	}
	return matched == 1
}
 | |
| 
 | |
| // LoginPost processes the login form
 | |
| func (h *Handlers) LoginPost(c *gin.Context) {
 | |
|     username := c.PostForm("username")
 | |
|     password := c.PostForm("password")
 | |
| 
 | |
|     user, err := h.authSvc.Authenticate(username, password)
 | |
|     if err != nil {
 | |
|         token, _ := c.Get("csrf_token")
 | |
|         // record failed password attempt
 | |
|         h.recordFailedAttempt(c, "password", username, nil)
 | |
|         c.HTML(http.StatusUnauthorized, "login", gin.H{
 | |
|             "app_name":        h.config.AppName,
 | |
|             "csrf_token":      token,
 | |
|             "error":           err.Error(),
 | |
|             "ContentTemplate": "login_content",
 | |
|             "ScriptsTemplate": "login_scripts",
 | |
|             "Page":            "login",
 | |
|         })
 | |
|         return
 | |
|     }
 | |
|     // If user has MFA enabled, require OTP before setting user_id
 | |
|     if user.MFASecret.Valid && user.MFASecret.String != "" {
 | |
|         session, _ := h.store.Get(c.Request, sessionCookieName)
 | |
|         session.Values["mfa_user_id"] = user.ID
 | |
|         _ = session.Save(c.Request, c.Writer)
 | |
|         c.Redirect(http.StatusFound, "/editor/mfa")
 | |
|         return
 | |
|     }
 | |
| 
 | |
|     // If admin created an enrollment for this user, force MFA setup after login
 | |
|     var pending int
 | |
|     if err := h.authSvc.DB.QueryRow(`SELECT 1 FROM mfa_enrollments WHERE user_id = ?`, user.ID).Scan(&pending); err == nil {
 | |
|         // normal login, then redirect to setup
 | |
|         session, _ := h.store.Get(c.Request, sessionCookieName)
 | |
|         session.Values["user_id"] = user.ID
 | |
|         _ = session.Save(c.Request, c.Writer)
 | |
|         c.Redirect(http.StatusFound, "/editor/profile/mfa/setup")
 | |
|         return
 | |
|     }
 | |
| 
 | |
|     // Create normal session
 | |
|     session, _ := h.store.Get(c.Request, sessionCookieName)
 | |
|     session.Values["user_id"] = user.ID
 | |
|     _ = session.Save(c.Request, c.Writer)
 | |
| 
 | |
|     c.Redirect(http.StatusFound, "/")
 | |
| }
 | |
| 
 | |
| // LogoutPost clears the session
 | |
| func (h *Handlers) LogoutPost(c *gin.Context) {
 | |
|     session, _ := h.store.Get(c.Request, sessionCookieName)
 | |
|     session.Options.MaxAge = -1
 | |
|     _ = session.Save(c.Request, c.Writer)
 | |
|     c.Redirect(http.StatusFound, "/editor/login")
 | |
| }
 | 
