Files
mailgosend/internal/webclient/handlers.go
T

905 lines
24 KiB
Go
Raw Normal View History

2026-05-22 06:06:44 +00:00
package webclient
import (
	"context"
	"crypto/subtle"
	"fmt"
	"log"
	"net/http"
	"net/mail"
	"strconv"
	"strings"
	"time"

	appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
	"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
	"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
	"ghb.freebede.com/nahakubuilder/mailgosend/internal/storage"
	"ghb.freebede.com/nahakubuilder/mailgosend/internal/totp"
)
// messagesPerPage is the maximum number of messages shown on one page of the
// mailbox message list (see mailboxView).
const messagesPerPage = 50
// ---- Root redirect ----
// rootRedirect sends the signed-in user to their Inbox. If the Inbox lookup
// fails or returns no mailbox, the user is redirected to /settings with an
// error flash instead.
func (s *Server) rootRedirect(w http.ResponseWriter, r *http.Request) {
	user := s.currentUser(r)
	ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
	defer cancel()
	inbox, err := s.deps.DB.GetMailboxByType(ctx, user.ID, models.MailboxInbox)
	if err != nil {
		// Log the real failure so a DB problem is not silently presented to
		// the user as a missing mailbox.
		log.Printf("[webmail] root redirect: inbox lookup: %v", err)
	}
	if err != nil || inbox == nil {
		redirect(w, r, "/settings", "", "No inbox found. Contact your admin.")
		return
	}
	http.Redirect(w, r, fmt.Sprintf("/mail/%d", inbox.ID), http.StatusSeeOther)
}
// ---- Mailbox view (message list) ----
// mailboxPage is the template data for the "mail" (message list) view
// rendered by mailboxView.
type mailboxPage struct {
	basePage
	CurrentBox *models.Mailbox   // mailbox being listed
	Messages   []*db.IMAPMessage // messages on the current page, newest first
	Query      string            // trimmed ?q= search text ("" when not searching)
	// UID of the message just before this page's first entry in newest-first
	// order, i.e. the last message of the previous page (0 = none).
	PrevPage uint32
	// UID of the first message after this page, i.e. the first message of the
	// next page (0 = none).
	NextPage   uint32
	TotalCount int // number of messages matching Query, before pagination
}
// mailboxView renders the message list of one of the current user's
// mailboxes, newest first.
//
// Query parameters:
//   - q: case-insensitive substring filter on subject, sender address and
//     sender name.
//   - before: when > 0, the page starts at the newest message whose UID is
//     strictly less than this value; otherwise it starts at the newest message.
//
// Mailboxes that do not exist or belong to another user yield 404.
func (s *Server) mailboxView(w http.ResponseWriter, r *http.Request) {
	user := s.currentUser(r)
	boxID := pathID(r, "boxid")
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()
	box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
	if err != nil || box == nil || box.UserID != user.ID {
		http.NotFound(w, r)
		return
	}
	msgs, err := s.deps.DB.ListIMAPMessages(ctx, boxID)
	if err != nil {
		log.Printf("[webmail] list messages: %v", err)
		msgs = nil
	}
	// The DB list is in ascending UID order; reverse it for newest-first.
	for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
		msgs[i], msgs[j] = msgs[j], msgs[i]
	}

	// Apply the search filter in place (msgs is a slice owned by this handler).
	query := strings.TrimSpace(r.URL.Query().Get("q"))
	if query != "" {
		ql := strings.ToLower(query)
		filtered := msgs[:0]
		for _, m := range msgs {
			if strings.Contains(strings.ToLower(m.Subject), ql) ||
				strings.Contains(strings.ToLower(m.FromEmail), ql) ||
				strings.Contains(strings.ToLower(m.FromName), ql) {
				filtered = append(filtered, m)
			}
		}
		msgs = filtered
	}
	total := len(msgs)

	// Pagination over the newest-first (descending UID) list. Every page,
	// including the first, is the window msgs[start:end], so PrevPage and
	// NextPage are computed the same way on all pages.
	//
	// NOTE(review): "before" is exclusive, while NextPage is the UID of the
	// first message of the next page. If the template links
	// ?before={{.NextPage}} directly, that one message is skipped. Confirm the
	// template's convention.
	start := 0
	if beforeUID, _ := strconv.ParseUint(r.URL.Query().Get("before"), 10, 32); beforeUID > 0 {
		// Default to "past the end": when no message is older than the cursor
		// (stale or bogus link), show an empty page instead of the whole box.
		start = len(msgs)
		for i, m := range msgs {
			if uint64(m.UID) < beforeUID {
				start = i
				break
			}
		}
	}
	end := start + messagesPerPage
	if end > len(msgs) {
		end = len(msgs)
	}
	var prevPage, nextPage uint32
	if start > 0 {
		prevPage = msgs[start-1].UID
	}
	if end < len(msgs) {
		nextPage = msgs[end].UID
	}
	msgs = msgs[start:end]

	flash, errMsg := flashFrom(r)
	base := s.newBase(r, flash, errMsg)
	base.CurrentBoxID = boxID
	s.render(w, "mail", mailboxPage{
		basePage:   base,
		CurrentBox: box,
		Messages:   msgs,
		Query:      query,
		PrevPage:   prevPage,
		NextPage:   nextPage,
		TotalCount: total,
	})
}
// ---- Message view ----
// messagePage is the template data for the single-message ("message") view
// rendered by messageView.
type messagePage struct {
	basePage
	CurrentBox *models.Mailbox    // mailbox the message belongs to
	Message    *db.IMAPMessage    // message metadata and flags
	Body       *storage.BodyParts // body parts from storage (placeholder text if loading failed)
	// Neighbour UIDs for navigation (0 = none). Assuming ListIMAPMessages
	// returns ascending UIDs, messageView sets PrevUID to the next-higher
	// (newer) UID and NextUID to the next-lower (older) UID.
	PrevUID uint32
	NextUID uint32
}
// messageView renders one message from a mailbox owned by the current user.
//
// Processing steps:
//   - An unread message is marked read (best-effort).
//   - The body is loaded from storage; a placeholder is shown if that fails.
//   - The neighbouring UIDs are looked up for prev/next navigation.
//
// An unparsable UID, an unknown or foreign mailbox, or an unknown message
// yields 404.
func (s *Server) messageView(w http.ResponseWriter, r *http.Request) {
	user := s.currentUser(r)
	boxID := pathID(r, "boxid")
	uid64, err := strconv.ParseUint(r.PathValue("uid"), 10, 32)
	if err != nil {
		http.NotFound(w, r)
		return
	}
	uid := uint32(uid64)
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()
	box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
	if err != nil || box == nil || box.UserID != user.ID {
		http.NotFound(w, r)
		return
	}
	msg, err := s.deps.DB.GetIMAPMessageByUID(ctx, boxID, uid)
	if err != nil || msg == nil {
		http.NotFound(w, r)
		return
	}
	// Auto-mark as read, keeping the other flags as they are.
	if !msg.IsRead {
		if err := s.deps.DB.SetMessageFlags(ctx, msg.ID, true, msg.IsStarred, msg.IsDraft, msg.Flags); err != nil {
			log.Printf("[webmail] mark read: %v", err)
		} else {
			msg.IsRead = true
		}
	}
	// Decrypt body parts.
	body, err := s.deps.Store.GetBodyParts(ctx, user.ID, msg.ID)
	if err != nil {
		log.Printf("[webmail] get body: %v", err)
		body = &storage.BodyParts{Text: "[Error loading message body]"}
	}
	// Neighbour UIDs for navigation. ListIMAPMessages returns ascending UIDs
	// (oldest first) while the list view shows newest first, so:
	//   nextUID = msgs[i-1], the older neighbour (the row below in the list);
	//   prevUID = msgs[i+1], the newer neighbour (the row above in the list).
	// Navigation is best-effort: on error the links are simply omitted.
	msgs, err := s.deps.DB.ListIMAPMessages(ctx, boxID)
	if err != nil {
		log.Printf("[webmail] list messages for navigation: %v", err)
	}
	var prevUID, nextUID uint32
	for i, m := range msgs {
		if m.UID != uid {
			continue
		}
		if i > 0 {
			nextUID = msgs[i-1].UID
		}
		if i+1 < len(msgs) {
			prevUID = msgs[i+1].UID
		}
		break
	}
	flash, errMsg := flashFrom(r)
	base := s.newBase(r, flash, errMsg)
	base.CurrentBoxID = boxID
	s.render(w, "message", messagePage{
		basePage:   base,
		CurrentBox: box,
		Message:    msg,
		Body:       body,
		PrevUID:    prevUID,
		NextUID:    nextUID,
	})
}
// ---- Message actions ----
// messageFlag toggles a flag on a message and redirects back.
//
// Form fields:
//   - flag: "read" or "star". Any other value leaves the flags unchanged.
//   - return: optional redirect target. It is honoured only when it is a
//     same-site absolute path (see isSameSiteReturnPath). Otherwise, the
//     handler redirects to the message view, which prevents an open redirect.
func (s *Server) messageFlag(w http.ResponseWriter, r *http.Request) {
	if !s.validateCSRF(w, r) {
		return
	}
	user := s.currentUser(r)
	boxID := pathID(r, "boxid")
	uid64, _ := strconv.ParseUint(r.PathValue("uid"), 10, 32)
	ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
	defer cancel()
	box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
	if err != nil || box == nil || box.UserID != user.ID {
		http.NotFound(w, r)
		return
	}
	msg, err := s.deps.DB.GetIMAPMessageByUID(ctx, boxID, uint32(uid64))
	if err != nil || msg == nil {
		http.NotFound(w, r)
		return
	}
	isRead, isStar := msg.IsRead, msg.IsStarred
	switch r.FormValue("flag") {
	case "read":
		isRead = !isRead
	case "star":
		isStar = !isStar
	}
	if err := s.deps.DB.SetMessageFlags(ctx, msg.ID, isRead, isStar, msg.IsDraft, msg.Flags); err != nil {
		log.Printf("[webmail] flag: %v", err)
	}
	returnTo := r.FormValue("return")
	if !isSameSiteReturnPath(returnTo) {
		returnTo = fmt.Sprintf("/mail/%d/%d", boxID, uid64)
	}
	http.Redirect(w, r, returnTo, http.StatusSeeOther)
}

// isSameSiteReturnPath reports whether p can safely be used as a same-site
// redirect target. It must satisfy all of the following:
//   - It starts with a single "/". A "//host" prefix is a protocol-relative
//     URL that http.Redirect passes through unchanged.
//   - It contains no backslash, which browsers treat like "/", so "/\host" also
//     leaves the site.
//   - It contains no ASCII control characters. Browsers strip tab/CR/LF, which
//     could turn "/\t/host" into "//host".
func isSameSiteReturnPath(p string) bool {
	if p == "" || p[0] != '/' || strings.HasPrefix(p, "//") {
		return false
	}
	for i := 0; i < len(p); i++ {
		if c := p[i]; c < 0x20 || c == 0x7f || c == '\\' {
			return false
		}
	}
	return true
}
// messageTrash moves a message into the current user's Trash folder, or
// permanently deletes it when it is already in Trash. It then redirects to the
// mailbox list with a flash message that reflects what actually happened.
func (s *Server) messageTrash(w http.ResponseWriter, r *http.Request) {
	if !s.validateCSRF(w, r) {
		return
	}
	user := s.currentUser(r)
	boxID := pathID(r, "boxid")
	uid64, _ := strconv.ParseUint(r.PathValue("uid"), 10, 32)
	ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
	defer cancel()
	box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
	if err != nil || box == nil || box.UserID != user.ID {
		http.NotFound(w, r)
		return
	}
	msg, err := s.deps.DB.GetIMAPMessageByUID(ctx, boxID, uint32(uid64))
	if err != nil || msg == nil {
		http.NotFound(w, r)
		return
	}
	listURL := fmt.Sprintf("/mail/%d", boxID)

	if box.Type == models.MailboxTrash {
		// Already in Trash: flag the message deleted, then purge the mailbox.
		// NOTE(review): HardDeleteMessages is given the whole mailbox ID, so it
		// presumably also purges any other Trash messages already flagged
		// deleted (e.g. via IMAP). Confirm that is intended.
		if err := s.deps.DB.SoftDeleteMessage(ctx, msg.ID); err != nil {
			log.Printf("[webmail] soft delete: %v", err)
			redirect(w, r, listURL, "", "Delete failed.")
			return
		}
		if _, err := s.deps.DB.HardDeleteMessages(ctx, boxID); err != nil {
			log.Printf("[webmail] hard delete: %v", err)
			redirect(w, r, listURL, "", "Delete failed.")
			return
		}
		redirect(w, r, listURL, "Message permanently deleted.", "")
		return
	}

	// Move to Trash: copy, then soft-delete the original.
	trashBox, err := s.deps.DB.GetMailboxByType(ctx, user.ID, models.MailboxTrash)
	if err != nil || trashBox == nil {
		redirect(w, r, listURL, "", "Trash folder not found.")
		return
	}
	if _, err := s.deps.DB.CopyMessageToMailbox(ctx, msg.ID, trashBox.ID, user.ID); err != nil {
		log.Printf("[webmail] copy to trash: %v", err)
		redirect(w, r, listURL, "", "Move to trash failed.")
		return
	}
	if err := s.deps.DB.SoftDeleteMessage(ctx, msg.ID); err != nil {
		log.Printf("[webmail] soft delete orig: %v", err)
		// The copy already exists in Trash; do not claim the move succeeded.
		redirect(w, r, listURL, "", "Copied to trash, but the original could not be removed.")
		return
	}
	redirect(w, r, listURL, "Moved to trash.", "")
}
// messageMove moves a message between two mailboxes owned by the current user.
// The destination comes from the "dest_box" form field. The move copies the
// message into the destination and then soft-deletes the original. The flash
// message reports partial failures instead of claiming success.
func (s *Server) messageMove(w http.ResponseWriter, r *http.Request) {
	if !s.validateCSRF(w, r) {
		return
	}
	user := s.currentUser(r)
	boxID := pathID(r, "boxid")
	uid64, _ := strconv.ParseUint(r.PathValue("uid"), 10, 32)
	destBoxID, _ := strconv.ParseInt(r.FormValue("dest_box"), 10, 64)
	ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
	defer cancel()
	box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
	if err != nil || box == nil || box.UserID != user.ID {
		http.NotFound(w, r)
		return
	}
	listURL := fmt.Sprintf("/mail/%d", boxID)
	destBox, err := s.deps.DB.GetMailboxByID(ctx, destBoxID)
	if err != nil || destBox == nil || destBox.UserID != user.ID {
		redirect(w, r, listURL, "", "Destination folder not found.")
		return
	}
	msg, err := s.deps.DB.GetIMAPMessageByUID(ctx, boxID, uint32(uid64))
	if err != nil || msg == nil {
		http.NotFound(w, r)
		return
	}
	if destBoxID == boxID {
		// Copying a message into its own mailbox and deleting the original
		// gains nothing; treat it as a no-op.
		redirect(w, r, listURL, "", "Message is already in that folder.")
		return
	}
	if _, err := s.deps.DB.CopyMessageToMailbox(ctx, msg.ID, destBoxID, user.ID); err != nil {
		log.Printf("[webmail] copy on move: %v", err)
		redirect(w, r, listURL, "", "Move failed.")
		return
	}
	if err := s.deps.DB.SoftDeleteMessage(ctx, msg.ID); err != nil {
		log.Printf("[webmail] soft delete on move: %v", err)
		redirect(w, r, listURL, "", "Copied, but the original could not be removed.")
		return
	}
	redirect(w, r, listURL, "Moved.", "")
}
// mailboxExpunge calls HardDeleteMessages on one of the current user's
// mailboxes and redirects back to it with a success or failure flash.
// Judging by messageTrash (soft-delete, then HardDeleteMessages), this purges
// messages already flagged deleted; it does not flag anything itself.
// Confirm against the DB layer.
func (s *Server) mailboxExpunge(w http.ResponseWriter, r *http.Request) {
	if !s.validateCSRF(w, r) {
		return
	}
	owner := s.currentUser(r)
	boxID := pathID(r, "boxid")
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()

	mbox, lookupErr := s.deps.DB.GetMailboxByID(ctx, boxID)
	owned := lookupErr == nil && mbox != nil && mbox.UserID == owner.ID
	if !owned {
		http.NotFound(w, r)
		return
	}

	back := fmt.Sprintf("/mail/%d", boxID)
	if _, purgeErr := s.deps.DB.HardDeleteMessages(ctx, boxID); purgeErr != nil {
		log.Printf("[webmail] expunge: %v", purgeErr)
		redirect(w, r, back, "", "Expunge failed.")
		return
	}
	redirect(w, r, back, "Expunged.", "")
}
// ---- Compose ----
// composePage is the template data for the compose form. composeGet pre-fills
// it when replying to or forwarding a message.
type composePage struct {
	basePage
	To         string // pre-filled recipient (set for replies)
	CC         string // CC field value (composeGet never pre-fills it)
	Subject    string // "Re: ..." / "Fwd: ..." subject for replies/forwards
	BodyText   string // quoted reply text or forwarded-message body
	InReplyTo  string // Message-ID of the message being replied to
	References string // References value for a reply (the original Message-ID)
}
// composeGet renders the compose form.
//
// With ?action=reply or ?action=forward plus ?boxid= and ?uid= naming a
// message in one of the current user's mailboxes, the form is pre-filled from
// that message. In every other case, including lookup failures, an empty form
// is shown.
func (s *Server) composeGet(w http.ResponseWriter, r *http.Request) {
	user := s.currentUser(r)
	flash, errMsg := flashFrom(r)
	base := s.newBase(r, flash, errMsg)
	p := composePage{basePage: base}
	// Handle reply/forward.
	action := r.URL.Query().Get("action")
	boxID, _ := strconv.ParseInt(r.URL.Query().Get("boxid"), 10, 64)
	uid64, _ := strconv.ParseUint(r.URL.Query().Get("uid"), 10, 32)
	if (action == "reply" || action == "forward") && boxID > 0 && uid64 > 0 {
		ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
		defer cancel()
		box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
		if err == nil && box != nil && box.UserID == user.ID {
			orig, err := s.deps.DB.GetIMAPMessageByUID(ctx, boxID, uint32(uid64))
			if err == nil && orig != nil {
				body, err := s.deps.Store.GetBodyParts(ctx, user.ID, orig.ID)
				if err != nil {
					log.Printf("[webmail] compose: get body: %v", err)
					body = &storage.BodyParts{}
				}
				if action == "reply" {
					p.To = formatComposeAddress(orig.FromName, orig.FromEmail)
					p.Subject = reSubject(orig.Subject)
					p.InReplyTo = orig.MessageID
					p.References = orig.MessageID
					p.BodyText = quoteBody(orig.FromEmail, orig.Date, body.Text)
				} else {
					p.Subject = fwdSubject(orig.Subject)
					p.BodyText = fwdBody(orig.FromEmail, orig.ToList, orig.Date, orig.Subject, body.Text)
				}
			}
		}
	}
	s.render(w, "compose", p)
}

// formatComposeAddress renders name/email as an RFC 5322 name-addr for the
// compose form. A bare email is returned when name is empty.
//
// The display name is quoted, with '\' and '"' escaped, when it contains
// characters that are special in address lists. This matters most for the
// comma: unquoted, "Doe, John <j@example.com>" would split into two
// recipients when the form is submitted. That assumes parseAddressList
// honours RFC 5322 quoting (as net/mail does); confirm.
func formatComposeAddress(name, email string) string {
	if name == "" {
		return email
	}
	if strings.ContainsAny(name, "()<>[]:;@\\,\"") {
		name = `"` + strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(name) + `"`
	}
	return name + " <" + email + ">"
}
// composeSend handles the compose form POST.
//
// Processing steps:
//   - Validate the form fields.
//   - Build an RFC 5322 message from the current user.
//   - Store a copy in Sent.
//   - Deliver to every distinct To/CC/BCC recipient. Recipients in local
//     domains are delivered directly; all others go through the outbound queue.
//
// Any delivery problem sends the user back to /compose with a per-recipient
// summary. On success the user goes to the Sent folder, or "/" if there is
// none.
func (s *Server) composeSend(w http.ResponseWriter, r *http.Request) {
	if !s.validateCSRF(w, r) {
		return
	}
	user := s.currentUser(r)
	ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
	defer cancel()

	// Header-bound free-text fields must stay on one line. An embedded CR or LF
	// could inject extra headers (e.g. a forged From) unless BuildRFC5322
	// encodes them itself, which this handler cannot rely on.
	oneLine := strings.NewReplacer("\r", " ", "\n", " ").Replace
	toRaw := strings.TrimSpace(r.FormValue("to"))
	ccRaw := strings.TrimSpace(r.FormValue("cc"))
	bccRaw := strings.TrimSpace(r.FormValue("bcc"))
	subject := strings.TrimSpace(oneLine(r.FormValue("subject")))
	bodyText := r.FormValue("body")
	inReplyTo := strings.TrimSpace(oneLine(r.FormValue("in_reply_to")))
	references := strings.TrimSpace(oneLine(r.FormValue("references")))

	if toRaw == "" {
		redirect(w, r, "/compose", "", "To field is required.")
		return
	}
	// RFC 5322 limits a header line to 998 octets. Cut on a UTF-8 boundary:
	// step back over continuation bytes (10xxxxxx) so no character is split.
	if len(subject) > 998 {
		cut := 998
		for cut > 0 && subject[cut]&0xC0 == 0x80 {
			cut--
		}
		subject = subject[:cut]
	}
	if len(bodyText) > 10*1024*1024 {
		redirect(w, r, "/compose", "", "Message body too large (max 10 MB).")
		return
	}

	toAddrs, err := parseAddressList(toRaw)
	if err != nil {
		redirect(w, r, "/compose", "", "Invalid To address: "+err.Error())
		return
	}
	if len(toAddrs) == 0 {
		// err is nil here, so it must not be dereferenced for the message.
		redirect(w, r, "/compose", "", "Invalid To address: no recipients given.")
		return
	}
	// Empty CC/BCC fields are fine; a non-empty but unparsable one is an error
	// rather than silently dropping those recipients.
	ccAddrs, err := parseAddressList(ccRaw)
	if err != nil && ccRaw != "" {
		redirect(w, r, "/compose", "", "Invalid CC address: "+err.Error())
		return
	}
	bccAddrs, err := parseAddressList(bccRaw)
	if err != nil && bccRaw != "" {
		redirect(w, r, "/compose", "", "Invalid BCC address: "+err.Error())
		return
	}

	// Build from address.
	displayName := user.DisplayName
	if displayName == "" {
		displayName = user.Username
	}
	fromAddr := mail.Address{Name: displayName, Address: user.Email}
	fromRFC := fromAddr.String()

	// Envelope recipients: To, then CC, then BCC, de-duplicated
	// case-insensitively so an address listed twice gets a single copy. A fresh
	// slice also avoids appending into toAddrs' backing array.
	allRecipients := make([]string, 0, len(toAddrs)+len(ccAddrs)+len(bccAddrs))
	seen := make(map[string]struct{}, cap(allRecipients))
	for _, group := range [][]string{toAddrs, ccAddrs, bccAddrs} {
		for _, a := range group {
			key := strings.ToLower(a)
			if _, dup := seen[key]; dup {
				continue
			}
			seen[key] = struct{}{}
			allRecipients = append(allRecipients, a)
		}
	}

	p := &ComposeParams{
		From:       fromRFC,
		FromEmail:  user.Email,
		To:         addressListRFC(toAddrs),
		CC:         addressListRFC(ccAddrs),
		BCC:        addressListRFC(bccAddrs),
		Subject:    subject,
		BodyText:   bodyText,
		InReplyTo:  inReplyTo,
		References: references,
	}
	raw, err := BuildRFC5322(p)
	if err != nil {
		log.Printf("[webmail] build message: %v", err)
		redirect(w, r, "/compose", "", "Failed to build message.")
		return
	}
	incomingMsg := &storage.IncomingMessage{
		Raw:       raw,
		FromEmail: user.Email,
		FromName:  displayName,
		ToList:    strings.Join(toAddrs, ", "),
		CCList:    strings.Join(ccAddrs, ", "),
		BCCList:   strings.Join(bccAddrs, ", "),
		Subject:   subject,
		Date:      time.Now().UTC(),
		MessageID: p.MessageID,
	}
	// Save to Sent.
	if err := s.saveSentCopy(ctx, user.ID, raw, incomingMsg); err != nil {
		log.Printf("[webmail] save sent: %v", err)
	}

	// Deliver to each recipient.
	var deliveryErrors []string
	for _, rcpt := range allRecipients {
		// Domain names are case-insensitive; normalise before the lookup.
		domain := ""
		if idx := strings.LastIndex(rcpt, "@"); idx >= 0 {
			domain = strings.ToLower(rcpt[idx+1:])
		}
		isLocal := false
		if domain != "" {
			local, err := s.deps.DB.IsLocalDomain(ctx, domain)
			if err != nil {
				// Don't guess: treating a local recipient as remote on a lookup
				// error would hand the message to the outbound queue.
				log.Printf("[webmail] local domain check %s: %v", domain, err)
				deliveryErrors = append(deliveryErrors, rcpt+": delivery failed (domain lookup error)")
				continue
			}
			isLocal = local
		}
		if isLocal {
			if err := s.deliverLocally(ctx, rcpt, raw, incomingMsg); err != nil {
				log.Printf("[webmail] local deliver %s: %v", rcpt, err)
				deliveryErrors = append(deliveryErrors, rcpt+": "+err.Error())
			}
		} else {
			if err := s.enqueueForDelivery(ctx, user.Email, rcpt, raw, p.MessageID); err != nil {
				log.Printf("[webmail] enqueue %s: %v", rcpt, err)
				deliveryErrors = append(deliveryErrors, rcpt+": could not be queued for delivery")
			}
		}
	}
	if len(deliveryErrors) > 0 {
		redirect(w, r, "/compose", "", "Sent with errors: "+strings.Join(deliveryErrors, "; "))
		return
	}
	// Redirect to Sent folder.
	sentBox, _ := s.deps.DB.GetMailboxByType(ctx, user.ID, models.MailboxSent)
	if sentBox != nil {
		redirect(w, r, fmt.Sprintf("/mail/%d", sentBox.ID), "Message sent.", "")
	} else {
		redirect(w, r, "/", "Message sent.", "")
	}
}
// ---- Settings ----
// settingsPage is the template data for the account settings view.
type settingsPage struct {
	basePage
	AccountUser *models.User // the signed-in user whose settings are shown
}
// settingsGet renders the "settings" template for the signed-in user,
// including any pending flash/error message.
func (s *Server) settingsGet(w http.ResponseWriter, r *http.Request) {
	account := s.currentUser(r)
	okMsg, failMsg := flashFrom(r)
	page := settingsPage{
		AccountUser: account,
		basePage:    s.newBase(r, okMsg, failMsg),
	}
	s.render(w, "settings", page)
}
// settingsPassword changes the signed-in user's password.
//
// Preconditions: the CSRF token, the current password and matching
// new/confirm passwords (8-1024 bytes) must all be valid.
//
// On success the new hash is stored, all sessions are destroyed, and a fresh
// session is created for this request. The flash only claims other sessions
// were logged out when that step actually succeeded.
func (s *Server) settingsPassword(w http.ResponseWriter, r *http.Request) {
	if !s.validateCSRF(w, r) {
		return
	}
	user := s.currentUser(r)
	current := r.FormValue("current_password")
	newPw := r.FormValue("new_password")
	confirm := r.FormValue("confirm_password")
	if err := appCrypto.CheckPassword(user.PasswordHash, current); err != nil {
		redirect(w, r, "/settings", "", "Current password is incorrect.")
		return
	}
	if len(newPw) < 8 || len(newPw) > 1024 {
		redirect(w, r, "/settings", "", "New password must be 8-1024 characters.")
		return
	}
	if newPw != confirm {
		redirect(w, r, "/settings", "", "Passwords do not match.")
		return
	}
	hash, err := appCrypto.HashPassword(newPw)
	if err != nil {
		redirect(w, r, "/settings", "", "Password error.")
		return
	}
	ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
	defer cancel()
	if err := s.deps.DB.SetUserPassword(ctx, user.ID, hash); err != nil {
		log.Printf("[webmail] set password: %v", err)
		redirect(w, r, "/settings", "", "Failed to update password.")
		return
	}
	// Log out all other sessions for security.
	sessionsCleared := true
	if err := s.deps.Sessions.DestroyAll(ctx, user.ID); err != nil {
		log.Printf("[webmail] destroy sessions: %v", err)
		sessionsCleared = false
	}
	// Create fresh session for current request.
	if _, err := s.deps.Sessions.Create(w, r, user.ID); err != nil {
		log.Printf("[webmail] re-create session: %v", err)
	}
	if !sessionsCleared {
		// Don't give the user a false sense of security about other sessions.
		redirect(w, r, "/settings", "", "Password updated, but other sessions could not be logged out.")
		return
	}
	redirect(w, r, "/settings", "Password updated. All other sessions logged out.", "")
}
// settingsDisplay updates the signed-in user's display name from the
// "display_name" form field.
//
// The name is used as the From display name in composeSend, so it is
// normalised as follows:
//   - Tabs, CRs and LFs become spaces, and other ASCII control characters are
//     dropped.
//   - Surrounding space is trimmed.
//   - The result is capped at 255 bytes without splitting a UTF-8 character.
func (s *Server) settingsDisplay(w http.ResponseWriter, r *http.Request) {
	if !s.validateCSRF(w, r) {
		return
	}
	user := s.currentUser(r)
	displayName := strings.TrimSpace(strings.Map(func(c rune) rune {
		switch {
		case c == '\r' || c == '\n' || c == '\t':
			return ' '
		case c < 0x20 || c == 0x7f:
			return -1 // drop
		}
		return c
	}, r.FormValue("display_name")))
	if len(displayName) > 255 {
		// Step back over UTF-8 continuation bytes (10xxxxxx) so the cut lands
		// on a character boundary.
		cut := 255
		for cut > 0 && displayName[cut]&0xC0 == 0x80 {
			cut--
		}
		displayName = displayName[:cut]
	}
	ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
	defer cancel()
	if err := s.deps.DB.SetUserDisplayName(ctx, user.ID, displayName); err != nil {
		log.Printf("[webmail] display name: %v", err)
		redirect(w, r, "/settings", "", "Update failed.")
		return
	}
	redirect(w, r, "/settings", "Display name updated.", "")
}
// ---- helpers ----
func pathID(r *http.Request, key string) int64 {
id, _ := strconv.ParseInt(r.PathValue(key), 10, 64)
return id
}
// reSubject returns the subject line for a reply. It returns s unchanged when
// s already starts with "re:" in any letter case; otherwise it prefixes "Re: ".
func reSubject(s string) string {
	if !strings.HasPrefix(strings.ToLower(s), "re:") {
		s = "Re: " + s
	}
	return s
}
// fwdSubject returns the subject line for a forward. It returns s unchanged
// when s already starts with "fwd:" or "fw:" in any letter case; otherwise it
// prefixes "Fwd: ".
func fwdSubject(s string) string {
	switch lower := strings.ToLower(s); {
	case strings.HasPrefix(lower, "fwd:"), strings.HasPrefix(lower, "fw:"):
		return s
	default:
		return "Fwd: " + s
	}
}
// quoteBody builds the quoted-reply text for the compose body. The output
// consists of:
//   - two CRLF line breaks;
//   - the attribution line "On <date>, <from> wrote:";
//   - each line of text prefixed with "> ".
//
// Input lines may end in LF or CRLF; every output line ends in CRLF.
func quoteBody(from string, date time.Time, text string) string {
	var out strings.Builder
	out.WriteString("\r\n\r\nOn ")
	out.WriteString(date.Format("Mon, 2 Jan 2006 at 15:04"))
	out.WriteString(", ")
	out.WriteString(from)
	out.WriteString(" wrote:\r\n")
	lines := strings.Split(text, "\n")
	for i := range lines {
		out.WriteString("> ")
		out.WriteString(strings.TrimRight(lines[i], "\r"))
		out.WriteString("\r\n")
	}
	return out.String()
}
// fwdBody builds the forwarded-message text for the compose body. It emits two
// blank CRLF lines, a "-------- Forwarded Message --------" banner, and
// From/To/Date (RFC 1123 with numeric zone)/Subject lines. A blank line
// follows, and then the original text, unchanged.
func fwdBody(from, to string, date time.Time, subject, text string) string {
	header := strings.Join([]string{
		"",
		"",
		"-------- Forwarded Message --------",
		"From: " + from,
		"To: " + to,
		"Date: " + date.Format(time.RFC1123Z),
		"Subject: " + subject,
		"",
		"",
	}, "\r\n")
	return header + text
}
// ---- TOTP / MFA enrollment ----
// mfaEnrollPage is the template data for the TOTP enrollment form rendered by
// mfaEnrollGet.
type mfaEnrollPage struct {
	basePage
	Secret     string // base32, shown for manual entry
	OTPAuthURI string // otpauth:// URI for QR code
}
// mfaEnrollGet starts TOTP enrollment for the signed-in user.
//
// It generates a new secret and stashes it, unconfirmed, in a short-lived
// signed cookie (setPendingTOTPCookie). It then renders the enrollment form
// with the secret and its otpauth:// URI.
//
// The issuer shown in authenticator apps is the configured default domain,
// then the hostname, then "mailgosend". Each GET replaces any previously
// pending secret.
func (s *Server) mfaEnrollGet(w http.ResponseWriter, r *http.Request) {
	user := s.currentUser(r)
	flash, errMsg := flashFrom(r)
	secret, err := totp.GenerateSecret()
	if err != nil {
		log.Printf("[webmail] mfa generate: %v", err)
		redirect(w, r, "/settings", "", "Failed to generate MFA secret.")
		return
	}
	issuer := s.deps.Cfg.DefaultDomain
	if issuer == "" {
		issuer = s.deps.Cfg.Hostname
	}
	if issuer == "" {
		issuer = "mailgosend"
	}
	uri := totp.OTPAuthURI(secret, user.Email, issuer)
	// Stash pending secret in a short-lived signed cookie so that the POST
	// can verify the code before persisting to the DB.
	s.setPendingTOTPCookie(w, user.ID, secret)
	// The page contains the raw TOTP secret; keep it out of browser and proxy
	// caches.
	w.Header().Set("Cache-Control", "no-store")
	s.render(w, "mfa_enroll", mfaEnrollPage{
		basePage:   s.newBase(r, flash, errMsg),
		Secret:     secret,
		OTPAuthURI: uri,
	})
}
// mfaEnrollPost completes TOTP enrollment.
//
// Processing steps:
//   - Check the 6-digit "code" form field against the secret stashed by
//     mfaEnrollGet.
//   - On a match, encrypt and store the secret.
//   - Generate and store recovery codes (best-effort).
//   - Enable MFA and clear the pending cookie.
//
// Validation failures send the user back to the enrollment form. Storage
// failures send the user to /settings with an error flash.
func (s *Server) mfaEnrollPost(w http.ResponseWriter, r *http.Request) {
	if !s.validateCSRF(w, r) {
		return
	}
	user := s.currentUser(r)
	code := strings.TrimSpace(r.FormValue("code"))
	if len(code) != totp.Digits {
		redirect(w, r, "/settings/mfa/enroll", "", "Enter the 6-digit code from your authenticator app.")
		return
	}
	// Read and validate pending secret cookie.
	secret, ok := s.pendingTOTPSecret(r, user.ID)
	if !ok {
		redirect(w, r, "/settings/mfa/enroll", "", "Enrollment session expired. Please start over.")
		return
	}
	if !totp.Verify(secret, code) {
		redirect(w, r, "/settings/mfa/enroll", "", "Code did not match. Check your authenticator and try again.")
		return
	}
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()
	// Encrypt and store secret.
	encSecret, err := s.deps.Crypt.EncryptForUser(user.ID, "totp", []byte(secret))
	if err != nil {
		log.Printf("[webmail] mfa encrypt secret: %v", err)
		redirect(w, r, "/settings", "", "MFA setup failed.")
		return
	}
	if err := s.deps.DB.SetMFASecret(ctx, user.ID, encSecret); err != nil {
		log.Printf("[webmail] mfa save secret: %v", err)
		redirect(w, r, "/settings", "", "MFA setup failed.")
		return
	}
	// Recovery codes are best-effort: any failure is logged (instead of being
	// silently discarded) but does not block enabling MFA.
	// NOTE(review): the codes are only stored encrypted. Nothing in this
	// handler displays them to the user; confirm they are shown elsewhere, or
	// they cannot be used.
	storeRecoveryCodes := func() error {
		codes, err := totp.GenerateRecoveryCodes()
		if err != nil {
			return fmt.Errorf("generate: %w", err)
		}
		codesJSON, err := totp.EncodeRecoveryCodes(codes)
		if err != nil {
			return fmt.Errorf("encode: %w", err)
		}
		encCodes, err := s.deps.Crypt.EncryptForUser(user.ID, "recovery", codesJSON)
		if err != nil {
			return fmt.Errorf("encrypt: %w", err)
		}
		if err := s.deps.DB.SetRecoveryCodes(ctx, user.ID, encCodes); err != nil {
			return fmt.Errorf("save: %w", err)
		}
		return nil
	}
	if err := storeRecoveryCodes(); err != nil {
		log.Printf("[webmail] mfa recovery codes: %v", err)
	}
	if err := s.deps.DB.SetMFAEnabled(ctx, user.ID, true); err != nil {
		log.Printf("[webmail] mfa enable: %v", err)
		redirect(w, r, "/settings", "", "MFA setup failed.")
		return
	}
	clearPendingTOTP(w)
	redirect(w, r, "/settings", "Two-factor authentication enabled.", "")
}
// mfaDisable switches off two-factor authentication for the signed-in user.
// The request needs a valid CSRF token, and the "password" form field must
// match the user's current password. On success it calls ClearMFA for the
// user; otherwise it redirects to /settings with an error flash. An
// unparsable form yields 400.
func (s *Server) mfaDisable(w http.ResponseWriter, r *http.Request) {
	if !s.validateCSRF(w, r) {
		return
	}
	if parseErr := r.ParseForm(); parseErr != nil {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}
	account := s.currentUser(r)
	password := r.FormValue("password")
	if appCrypto.CheckPassword(account.PasswordHash, password) != nil {
		redirect(w, r, "/settings", "", "Incorrect password. MFA not disabled.")
		return
	}

	ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
	defer cancel()
	if clearErr := s.deps.DB.ClearMFA(ctx, account.ID); clearErr != nil {
		log.Printf("[webmail] mfa disable: %v", clearErr)
		redirect(w, r, "/settings", "", "Failed to disable MFA.")
		return
	}
	redirect(w, r, "/settings", "Two-factor authentication disabled.", "")
}
// ---- Pending TOTP cookie (enrollment flow) ----
// Cookie that carries a not-yet-confirmed TOTP secret from the enrollment GET
// to the enrollment POST: its name, and its lifetime in seconds (used both as
// the cookie MaxAge and as the server-side age limit).
const (
	pendingTOTPCookie = "mailgo_enroll"
	pendingTOTPMaxAge = 300
)
// setPendingTOTPCookie stores the unconfirmed TOTP secret for userID in the
// pending-enrollment cookie. The value has the form
// "<uid>|<unix-seconds>|<secret>|<mac>", where mac is preAuthMAC over the first
// three fields keyed with the session secret. The cookie is scoped to
// /settings/mfa, lives pendingTOTPMaxAge seconds, and is HttpOnly, Secure and
// SameSite=Strict.
func (s *Server) setPendingTOTPCookie(w http.ResponseWriter, userID int64, secret string) {
	payload := strings.Join([]string{
		strconv.FormatInt(userID, 10),
		strconv.FormatInt(time.Now().Unix(), 10),
		secret,
	}, "|")
	signature := preAuthMAC(s.deps.Cfg.SessionSecret, payload)
	cookie := &http.Cookie{
		Name:     pendingTOTPCookie,
		Value:    payload + "|" + signature,
		Path:     "/settings/mfa",
		MaxAge:   pendingTOTPMaxAge,
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
	}
	http.SetCookie(w, cookie)
}
// pendingTOTPSecret returns the unconfirmed TOTP secret stored by
// setPendingTOTPCookie. It reports false when any of the following holds:
//   - the cookie is missing or malformed;
//   - its MAC does not verify;
//   - it was issued for a different user than userID;
//   - it is older than pendingTOTPMaxAge seconds;
//   - it carries an empty secret.
func (s *Server) pendingTOTPSecret(r *http.Request, userID int64) (string, bool) {
	c, err := r.Cookie(pendingTOTPCookie)
	if err != nil {
		return "", false
	}
	// The MAC is the last "|"-separated field; everything before it is the
	// signed payload.
	idx := strings.LastIndex(c.Value, "|")
	if idx < 0 {
		return "", false
	}
	payload, gotMAC := c.Value[:idx], c.Value[idx+1:]
	wantMAC := preAuthMAC(s.deps.Cfg.SessionSecret, payload)
	// Compare in constant time so response timing cannot be used to recover
	// the MAC byte by byte. Lower-casing both sides keeps the original
	// case-insensitive match (as for hex-encoded MACs).
	if subtle.ConstantTimeCompare([]byte(strings.ToLower(gotMAC)), []byte(strings.ToLower(wantMAC))) != 1 {
		return "", false
	}
	// payload = uid|ts|secret
	parts := strings.SplitN(payload, "|", 3)
	if len(parts) != 3 {
		return "", false
	}
	uid, ts, secret := parts[0], parts[1], parts[2]
	id, err := strconv.ParseInt(uid, 10, 64)
	if err != nil || id != userID {
		return "", false
	}
	tsi, err := strconv.ParseInt(ts, 10, 64)
	if err != nil || time.Now().Unix()-tsi > pendingTOTPMaxAge {
		return "", false
	}
	if secret == "" {
		return "", false
	}
	return secret, true
}
// clearPendingTOTP deletes the pending-enrollment cookie. It re-issues the
// cookie with an empty value and a negative MaxAge, using the same name, Path
// and security attributes as setPendingTOTPCookie so the browser matches it.
func clearPendingTOTP(w http.ResponseWriter) {
	expired := http.Cookie{
		Name:     pendingTOTPCookie,
		Value:    "",
		Path:     "/settings/mfa",
		MaxAge:   -1, // negative MaxAge: delete the cookie now
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
	}
	http.SetCookie(w, &expired)
}