package handlers import ( "crypto/hmac" "crypto/sha1" "crypto/subtle" "crypto/rand" "encoding/base32" "encoding/binary" "fmt" "net/http" "net/url" "strings" "time" "github.com/gin-gonic/gin" "gobsidian/internal/utils" ) const sessionCookieName = "gobsidian_session" // LoginPage renders the login form func (h *Handlers) LoginPage(c *gin.Context) { // If already authenticated, redirect to home (respect URL prefix) if isAuthenticated(c) { c.Redirect(http.StatusFound, h.config.URLPrefix+"/") return } token, _ := c.Get("csrf_token") // propagate return_to if provided returnTo := c.Query("return_to") c.HTML(http.StatusOK, "login", gin.H{ "app_name": h.config.AppName, "csrf_token": token, "return_to": returnTo, "ContentTemplate": "login_content", "ScriptsTemplate": "login_scripts", "Page": "login", }) } // --- Failed login tracking and automatic IP bans --- const ( defaultPwdFailuresThreshold = 5 defaultMFAFailuresThreshold = 10 defaultFailuresWindow = 30 * time.Minute defaultBanDuration = 12 * time.Hour ) // recordFailedAttempt logs a failed attempt and applies IP ban if thresholds exceeded. func (h *Handlers) recordFailedAttempt(c *gin.Context, typ, username string, userID *int64) { ip := c.ClientIP() var uid interface{} if userID != nil { uid = *userID } // Insert failed attempt (best-effort) _, _ = h.authSvc.DB.Exec(`INSERT INTO failed_logins (ip, user_id, username, type) VALUES (?,?,?,?)`, ip, uid, username, typ) // Determine threshold using config overrides (fallback to defaults when zero) threshold := h.config.PwdFailuresThreshold if threshold <= 0 { threshold = defaultPwdFailuresThreshold } if typ == "mfa" { t2 := h.config.MFAFailuresThreshold if t2 <= 0 { t2 = defaultMFAFailuresThreshold } threshold = t2 } // Count recent failures for this IP and type var cnt int // Window from config (minutes) windowMinutes := h.config.FailuresWindowMinutes if windowMinutes <= 0 { windowMinutes = int(defaultFailuresWindow / time.Minute) } cutoff := time.Now().Add(-time.Duration(windowMinutes) * time.Minute) _ = h.authSvc.DB.QueryRow(`SELECT COUNT(1) FROM failed_logins WHERE ip = ? AND type = ? AND created_at >= ?`, ip, typ, cutoff).Scan(&cnt) if cnt >= threshold { // Duration/permanence from config banHours := h.config.AutoBanDurationHours if banHours <= 0 { banHours = int(defaultBanDuration / time.Hour) } makePermanent := h.config.AutoBanPermanent var until interface{} if makePermanent { until = nil } else { until = time.Now().Add(time.Duration(banHours) * time.Hour) } reason := fmt.Sprintf("auto-ban: %s failures=%d within %d minutes", typ, cnt, windowMinutes) // Upsert ban unless whitelisted or permanent existing if makePermanent { _, _ = h.authSvc.DB.Exec(` INSERT INTO ip_bans (ip, reason, until, permanent, whitelisted, created_at, updated_at) VALUES (?, ?, NULL, 1, 0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) ON CONFLICT(ip) DO UPDATE SET reason=excluded.reason, until=NULL, permanent=1, updated_at=CURRENT_TIMESTAMP WHERE ip_bans.whitelisted = 0 `, ip, reason) } else { _, _ = h.authSvc.DB.Exec(` INSERT INTO ip_bans (ip, reason, until, permanent, whitelisted, created_at, updated_at) VALUES (?, ?, ?, 0, 0, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) ON CONFLICT(ip) DO UPDATE SET reason=excluded.reason, until=excluded.until, permanent=0, updated_at=CURRENT_TIMESTAMP WHERE ip_bans.whitelisted = 0 AND ip_bans.permanent = 0 `, ip, reason, until) } } } // isAllDigits returns true if s consists only of ASCII digits 0-9 func isAllDigits(s string) bool { if s == "" { return false } for i := 0; i < len(s); i++ { if s[i] < '0' || s[i] > '9' { return false } } return true } // MFALoginPage shows OTP prompt when MFA is enabled func (h *Handlers) MFALoginPage(c *gin.Context) { session, _ := h.store.Get(c.Request, sessionCookieName) if _, ok := session.Values["mfa_user_id"]; !ok { c.Redirect(http.StatusFound, h.config.URLPrefix+"/editor/login") return } token, _ := c.Get("csrf_token") c.HTML(http.StatusOK, "mfa", gin.H{ "app_name": h.config.AppName, "csrf_token": token, "ContentTemplate": "mfa_content", "ScriptsTemplate": "mfa_scripts", "Page": "mfa", }) } // MFALoginVerify verifies OTP and completes login func (h *Handlers) MFALoginVerify(c *gin.Context) { code := strings.TrimSpace(c.PostForm("code")) if len(code) != 6 || !isAllDigits(code) { token, _ := c.Get("csrf_token") // record failed MFA attempt h.recordFailedAttempt(c, "mfa", "", nil) c.HTML(http.StatusUnauthorized, "mfa", gin.H{ "app_name": h.config.AppName, "csrf_token": token, "error": "Invalid code format", "ContentTemplate": "mfa_content", "ScriptsTemplate": "mfa_scripts", "Page": "mfa", }) return } session, _ := h.store.Get(c.Request, sessionCookieName) uidAny, ok := session.Values["mfa_user_id"] if !ok { c.Redirect(http.StatusFound, h.config.URLPrefix+"/editor/login") return } uid, _ := uidAny.(int64) var secret string if err := h.authSvc.DB.QueryRow(`SELECT mfa_secret FROM users WHERE id = ?`, uid).Scan(&secret); err != nil || secret == "" { h.recordFailedAttempt(c, "mfa", "", &uid) c.HTML(http.StatusUnauthorized, "mfa", gin.H{"error": "MFA not enabled", "Page": "mfa", "ContentTemplate": "mfa_content", "ScriptsTemplate": "mfa_scripts", "app_name": h.config.AppName}) return } if !verifyTOTP(secret, code, time.Now()) { token, _ := c.Get("csrf_token") h.recordFailedAttempt(c, "mfa", "", &uid) c.HTML(http.StatusUnauthorized, "mfa", gin.H{ "app_name": h.config.AppName, "csrf_token": token, "error": "Invalid code", "ContentTemplate": "mfa_content", "ScriptsTemplate": "mfa_scripts", "Page": "mfa", }) return } // success: set user_id and clear mfa_user_id delete(session.Values, "mfa_user_id") session.Values["user_id"] = uid // use return_to if set in session var dest string if v, ok := session.Values["return_to"].(string); ok { dest = sanitizeReturnTo(h.config.URLPrefix, v) delete(session.Values, "return_to") } _ = session.Save(c.Request, c.Writer) if dest == "" { dest = h.config.URLPrefix + "/" } c.Redirect(http.StatusFound, dest) } // ProfileMFASetupPage shows QR and input to verify during enrollment func (h *Handlers) ProfileMFASetupPage(c *gin.Context) { uidPtr := getUserIDPtr(c) if uidPtr == nil { c.Redirect(http.StatusFound, h.config.URLPrefix+"/editor/login") return } // ensure enrollment exists, otherwise create one var secret string err := h.authSvc.DB.QueryRow(`SELECT secret FROM mfa_enrollments WHERE user_id = ?`, *uidPtr).Scan(&secret) if err != nil || secret == "" { // create new enrollment s, e := generateBase32Secret() if e != nil { c.HTML(http.StatusInternalServerError, "error", gin.H{"error": "Failed to create enrollment", "Page": "error", "ContentTemplate": "error_content", "ScriptsTemplate": "error_scripts", "app_name": h.config.AppName}) return } _, _ = h.authSvc.DB.Exec(`INSERT OR REPLACE INTO mfa_enrollments (user_id, secret) VALUES (?, ?)`, *uidPtr, s) secret = s } // Fetch username for label var username string _ = h.authSvc.DB.QueryRow(`SELECT username FROM users WHERE id = ?`, *uidPtr).Scan(&username) issuer := h.config.AppName label := url.PathEscape(fmt.Sprintf("%s:%s", issuer, username)) otpauth := fmt.Sprintf("otpauth://totp/%s?secret=%s&issuer=%s&digits=6&period=30&algorithm=SHA1", label, secret, url.QueryEscape(issuer)) // Build sidebar tree for consistent UI and pass auth flags notesTree, _ := utils.BuildTreeStructure(h.config.NotesDir, h.config.NotesDirHideSidepane, h.config) c.HTML(http.StatusOK, "mfa_setup", gin.H{ "app_name": h.config.AppName, "notes_tree": notesTree, "active_path": []string{}, "current_note": nil, "breadcrumbs": utils.GenerateBreadcrumbs(""), "Authenticated": true, "IsAdmin": isAdmin(c), "Secret": secret, "OTPAuthURI": otpauth, "ContentTemplate": "mfa_setup_content", "ScriptsTemplate": "mfa_setup_scripts", "Page": "mfa_setup", }) } // ProfileMFASetupVerify finalizes enrollment by verifying a TOTP code and setting mfa_secret func (h *Handlers) ProfileMFASetupVerify(c *gin.Context) { uidPtr := getUserIDPtr(c) if uidPtr == nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "not authenticated"}) return } code := strings.TrimSpace(c.PostForm("code")) if len(code) != 6 || !isAllDigits(code) { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid code format"}) return } var secret string if err := h.authSvc.DB.QueryRow(`SELECT secret FROM mfa_enrollments WHERE user_id = ?`, *uidPtr).Scan(&secret); err != nil || secret == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "no enrollment found"}) return } if !verifyTOTP(secret, code, time.Now()) { c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid code"}) return } // move secret to users and delete enrollment if _, err := h.authSvc.DB.Exec(`UPDATE users SET mfa_secret = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, secret, *uidPtr); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } _, _ = h.authSvc.DB.Exec(`DELETE FROM mfa_enrollments WHERE user_id = ?`, *uidPtr) c.JSON(http.StatusOK, gin.H{"success": true}) } // TOTP helpers (SHA1, 30s window, 6 digits) func generateBase32Secret() (string, error) { // 20 random bytes -> base32 without padding b := make([]byte, 20) if _, err := rand.Read(b); err != nil { return "", err } enc := base32.StdEncoding.WithPadding(base32.NoPadding) return enc.EncodeToString(b), nil } func hotp(secret []byte, counter uint64) string { // HMAC-SHA1 var buf [8]byte binary.BigEndian.PutUint64(buf[:], counter) mac := hmac.New(sha1.New, secret) mac.Write(buf[:]) sum := mac.Sum(nil) // dynamic truncation offset := sum[len(sum)-1] & 0x0F code := (uint32(sum[offset])&0x7F)<<24 | (uint32(sum[offset+1])&0xFF)<<16 | (uint32(sum[offset+2])&0xFF)<<8 | (uint32(sum[offset+3]) & 0xFF) return fmt.Sprintf("%06d", code%1000000) } func verifyTOTP(base32Secret, code string, t time.Time) bool { enc := base32.StdEncoding.WithPadding(base32.NoPadding) key, err := enc.DecodeString(strings.ToUpper(base32Secret)) if err != nil { return false } timestep := uint64(t.Unix() / 30) // allow +/- 1 window candidates := []string{ hotp(key, timestep-1), hotp(key, timestep), hotp(key, timestep+1), } for _, c := range candidates { if subtle.ConstantTimeCompare([]byte(c), []byte(code)) == 1 { return true } } return false } // LoginPost processes the login form func (h *Handlers) LoginPost(c *gin.Context) { username := c.PostForm("username") password := c.PostForm("password") returnTo := strings.TrimSpace(c.PostForm("return_to")) user, err := h.authSvc.Authenticate(username, password) if err != nil { token, _ := c.Get("csrf_token") // record failed password attempt h.recordFailedAttempt(c, "password", username, nil) c.HTML(http.StatusUnauthorized, "login", gin.H{ "app_name": h.config.AppName, "csrf_token": token, "error": err.Error(), "return_to": returnTo, "ContentTemplate": "login_content", "ScriptsTemplate": "login_scripts", "Page": "login", }) return } // If user has MFA enabled, require OTP before setting user_id if user.MFASecret.Valid && user.MFASecret.String != "" { session, _ := h.store.Get(c.Request, sessionCookieName) session.Values["mfa_user_id"] = user.ID if rt := sanitizeReturnTo(h.config.URLPrefix, returnTo); rt != "" { session.Values["return_to"] = rt } _ = session.Save(c.Request, c.Writer) c.Redirect(http.StatusFound, h.config.URLPrefix+"/editor/mfa") return } // Do NOT automatically force MFA setup just because an enrollment row exists. // Some deployments may leave stale enrollment rows; we only require MFA when // the user actually has MFA enabled (mfa_secret set) or when they explicitly // navigate to setup from profile. // Create normal session session, _ := h.store.Get(c.Request, sessionCookieName) session.Values["user_id"] = user.ID _ = session.Save(c.Request, c.Writer) // Redirect to requested page if provided and safe; otherwise home if rt := sanitizeReturnTo(h.config.URLPrefix, returnTo); rt != "" { c.Redirect(http.StatusFound, rt) } else { c.Redirect(http.StatusFound, h.config.URLPrefix+"/") } } // LogoutPost clears the session func (h *Handlers) LogoutPost(c *gin.Context) { session, _ := h.store.Get(c.Request, sessionCookieName) session.Options.MaxAge = -1 _ = session.Save(c.Request, c.Writer) c.Redirect(http.StatusFound, h.config.URLPrefix+"/editor/login") } // sanitizeReturnTo ensures the provided return_to is a safe in-app path. // It rejects absolute URLs and protocol-relative URLs. When URLPrefix is set, // it enforces that the destination stays within that prefix; if a bare // "/..." path is provided, it will be rewritten to include the prefix. func sanitizeReturnTo(prefix, v string) string { v = strings.TrimSpace(v) if v == "" { return "" } // Disallow absolute and protocol-relative URLs if strings.HasPrefix(v, "//") { return "" } if u, err := url.Parse(v); err != nil || (u != nil && u.IsAbs()) { return "" } // Must be a path if !strings.HasPrefix(v, "/") { v = "/" + v } // Enforce prefix containment when configured if prefix != "" { if strings.HasPrefix(v, prefix+"/") || v == prefix || v == prefix+"/" { return v } // If it's a root-relative path without prefix, rewrite into prefix return prefix + v } return v }