316 lines
8.0 KiB
Go
316 lines
8.0 KiB
Go
package webclient
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
|
"ghb.freebede.com/nahakubuilder/mailgosend/internal/totp"
|
|
)
|
|
|
|
// Parameters of the short-lived cookie that carries a user through the
// second (MFA) login step after the password has been accepted.
const (
	// preAuthCookieName is the name of the pre-auth cookie.
	preAuthCookieName = "mailgo_preauth"
	// preAuthMaxAge is the cookie lifetime in seconds (5 minutes).
	preAuthMaxAge = 300
)
|
|
|
|
// ---- Login (step 1: password) ----
|
|
|
|
func (s *Server) loginGet(w http.ResponseWriter, r *http.Request) {
|
|
if s.currentUser(r) != nil {
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
return
|
|
}
|
|
flash, errMsg := flashFrom(r)
|
|
s.render(w, "login", struct{ basePage }{
|
|
basePage: basePage{Flash: flash, Error: errMsg},
|
|
})
|
|
}
|
|
|
|
func (s *Server) loginPost(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
http.Error(w, "bad request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
email := strings.ToLower(strings.TrimSpace(r.FormValue("email")))
|
|
password := r.FormValue("password")
|
|
clientIP := realIP(r)
|
|
|
|
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
// Brute-force check.
|
|
if s.deps.Brute != nil && s.deps.Cfg.BruteMaxTries > 0 {
|
|
banned, err := s.deps.Brute.IsBanned(ctx, clientIP)
|
|
if err != nil {
|
|
log.Printf("[webmail] brute check: %v", err)
|
|
}
|
|
if banned {
|
|
redirect(w, r, "/login", "", "Too many failed attempts. Try again later.")
|
|
return
|
|
}
|
|
}
|
|
|
|
if email == "" || password == "" || len(email) > 254 || len(password) > 1024 {
|
|
s.recordAttempt(ctx, clientIP, email, false)
|
|
redirect(w, r, "/login", "", "Invalid credentials.")
|
|
return
|
|
}
|
|
|
|
user, err := s.deps.DB.GetUserByEmail(ctx, email)
|
|
if err != nil {
|
|
log.Printf("[webmail] login db: %v", err)
|
|
redirect(w, r, "/login", "", "Internal error.")
|
|
return
|
|
}
|
|
|
|
if user == nil || !user.Enabled {
|
|
s.recordAttempt(ctx, clientIP, email, false)
|
|
redirect(w, r, "/login", "", "Invalid credentials.")
|
|
return
|
|
}
|
|
|
|
if err := appCrypto.CheckPassword(user.PasswordHash, password); err != nil {
|
|
s.recordAttempt(ctx, clientIP, email, false)
|
|
redirect(w, r, "/login", "", "Invalid credentials.")
|
|
return
|
|
}
|
|
|
|
// Password OK. Check MFA.
|
|
if user.MFAEnabled && len(user.MFASecretEnc) > 0 {
|
|
// Issue pre-auth cookie and redirect to TOTP step.
|
|
s.setPreAuthCookie(w, user.ID)
|
|
http.Redirect(w, r, "/login/mfa", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
// No MFA — create session directly.
|
|
s.recordAttempt(ctx, clientIP, email, true)
|
|
if _, err := s.deps.Sessions.Create(w, r, user.ID); err != nil {
|
|
log.Printf("[webmail] session create: %v", err)
|
|
redirect(w, r, "/login", "", "Session error.")
|
|
return
|
|
}
|
|
s.deps.DB.UpdateLastLogin(ctx, user.ID)
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
}
|
|
|
|
// ---- Login (step 2: TOTP) ----
|
|
|
|
func (s *Server) mfaGet(w http.ResponseWriter, r *http.Request) {
|
|
if s.currentUser(r) != nil {
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
return
|
|
}
|
|
if _, ok := s.preAuthUserID(r); !ok {
|
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
return
|
|
}
|
|
flash, errMsg := flashFrom(r)
|
|
s.render(w, "mfa", struct{ basePage }{
|
|
basePage: basePage{Flash: flash, Error: errMsg},
|
|
})
|
|
}
|
|
|
|
func (s *Server) mfaPost(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
http.Error(w, "bad request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
userID, ok := s.preAuthUserID(r)
|
|
if !ok {
|
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
clientIP := realIP(r)
|
|
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
// Brute-force check (TOTP codes are brute-forceable too).
|
|
if s.deps.Brute != nil && s.deps.Cfg.BruteMaxTries > 0 {
|
|
banned, _ := s.deps.Brute.IsBanned(ctx, clientIP)
|
|
if banned {
|
|
clearPreAuth(w)
|
|
redirect(w, r, "/login", "", "Too many failed attempts. Try again later.")
|
|
return
|
|
}
|
|
}
|
|
|
|
code := strings.TrimSpace(r.FormValue("code"))
|
|
if len(code) == 0 || len(code) > 64 {
|
|
s.recordAttempt(ctx, clientIP, "", false)
|
|
redirect(w, r, "/login/mfa", "", "Invalid code.")
|
|
return
|
|
}
|
|
|
|
user, err := s.deps.DB.GetUserByID(ctx, userID)
|
|
if err != nil || user == nil || !user.Enabled || !user.MFAEnabled {
|
|
clearPreAuth(w)
|
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
// Decrypt the TOTP secret.
|
|
secretRaw, err := s.deps.Crypt.DecryptForUser(user.ID, "totp", user.MFASecretEnc)
|
|
if err != nil {
|
|
log.Printf("[webmail] mfa decrypt: %v", err)
|
|
clearPreAuth(w)
|
|
redirect(w, r, "/login", "", "MFA error. Please try again.")
|
|
return
|
|
}
|
|
|
|
var authenticated bool
|
|
|
|
if len(code) == totp.Digits {
|
|
// TOTP path.
|
|
authenticated = totp.Verify(string(secretRaw), code)
|
|
} else {
|
|
// Recovery code path.
|
|
if len(user.RecoveryCodesEnc) > 0 {
|
|
codesRaw, cerr := s.deps.Crypt.DecryptForUser(user.ID, "recovery", user.RecoveryCodesEnc)
|
|
if cerr == nil {
|
|
codes, cerr := totp.DecodeRecoveryCodes(codesRaw)
|
|
if cerr == nil {
|
|
updated, consumed := totp.ConsumeRecoveryCode(codes, code)
|
|
if consumed {
|
|
authenticated = true
|
|
newEnc, encErr := s.deps.Crypt.EncryptForUser(user.ID, "recovery", mustMarshalCodes(updated))
|
|
if encErr == nil {
|
|
_ = s.deps.DB.SetRecoveryCodes(ctx, user.ID, newEnc)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if !authenticated {
|
|
s.recordAttempt(ctx, clientIP, user.Email, false)
|
|
redirect(w, r, "/login/mfa", "", "Invalid code.")
|
|
return
|
|
}
|
|
|
|
clearPreAuth(w)
|
|
s.recordAttempt(ctx, clientIP, user.Email, true)
|
|
if _, err := s.deps.Sessions.Create(w, r, user.ID); err != nil {
|
|
log.Printf("[webmail] session create: %v", err)
|
|
redirect(w, r, "/login", "", "Session error.")
|
|
return
|
|
}
|
|
s.deps.DB.UpdateLastLogin(ctx, user.ID)
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
}
|
|
|
|
// ---- Logout ----
|
|
|
|
func (s *Server) logout(w http.ResponseWriter, r *http.Request) {
|
|
if err := s.deps.Sessions.Destroy(w, r); err != nil {
|
|
log.Printf("[webmail] session destroy: %v", err)
|
|
}
|
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
|
}
|
|
|
|
// ---- Pre-auth cookie ----
|
|
|
|
func (s *Server) setPreAuthCookie(w http.ResponseWriter, userID int64) {
|
|
ts := fmt.Sprintf("%d", time.Now().Unix())
|
|
uid := fmt.Sprintf("%d", userID)
|
|
payload := uid + "|" + ts
|
|
mac := preAuthMAC(s.deps.Cfg.SessionSecret, payload)
|
|
value := payload + "|" + mac
|
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: preAuthCookieName,
|
|
Value: value,
|
|
Path: "/login",
|
|
MaxAge: preAuthMaxAge,
|
|
HttpOnly: true,
|
|
Secure: true,
|
|
SameSite: http.SameSiteStrictMode,
|
|
})
|
|
}
|
|
|
|
func (s *Server) preAuthUserID(r *http.Request) (int64, bool) {
|
|
c, err := r.Cookie(preAuthCookieName)
|
|
if err != nil {
|
|
return 0, false
|
|
}
|
|
parts := strings.SplitN(c.Value, "|", 3)
|
|
if len(parts) != 3 {
|
|
return 0, false
|
|
}
|
|
uid, ts, gotMAC := parts[0], parts[1], parts[2]
|
|
payload := uid + "|" + ts
|
|
|
|
wantMAC := preAuthMAC(s.deps.Cfg.SessionSecret, payload)
|
|
if !hmac.Equal([]byte(gotMAC), []byte(wantMAC)) {
|
|
return 0, false
|
|
}
|
|
|
|
tsi, err := strconv.ParseInt(ts, 10, 64)
|
|
if err != nil || time.Now().Unix()-tsi > preAuthMaxAge {
|
|
return 0, false
|
|
}
|
|
|
|
id, err := strconv.ParseInt(uid, 10, 64)
|
|
if err != nil || id <= 0 {
|
|
return 0, false
|
|
}
|
|
return id, true
|
|
}
|
|
|
|
func clearPreAuth(w http.ResponseWriter) {
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: preAuthCookieName,
|
|
Value: "",
|
|
Path: "/login",
|
|
MaxAge: -1,
|
|
HttpOnly: true,
|
|
Secure: true,
|
|
SameSite: http.SameSiteStrictMode,
|
|
})
|
|
}
|
|
|
|
// preAuthMAC returns the lowercase hex encoding of the HMAC-SHA256 of
// payload keyed with secret.
func preAuthMAC(secret []byte, payload string) string {
	h := hmac.New(sha256.New, secret)
	h.Write([]byte(payload))
	digest := h.Sum(make([]byte, 0, sha256.Size))
	return hex.EncodeToString(digest)
}
|
|
|
|
func mustMarshalCodes(codes []string) []byte {
|
|
data, err := totp.EncodeRecoveryCodes(codes)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("totp marshal: %v", err))
|
|
}
|
|
return data
|
|
}
|
|
|
|
// ---- Shared helpers ----
|
|
|
|
func (s *Server) recordAttempt(ctx context.Context, ip, email string, success bool) {
|
|
if s.deps.Brute != nil {
|
|
s.deps.Brute.RecordAttempt(ctx, ip, email, success)
|
|
}
|
|
}
|
|
|
|
// realIP returns the address used to identify the client of r.
//
// The first entry of the X-Forwarded-For header wins when it parses as an IP
// address. Otherwise the host part of r.RemoteAddr is returned. When
// RemoteAddr carries no port (so it cannot be split), it is returned
// unchanged rather than collapsing to an empty string, which would make
// every such client share one brute-force key.
//
// NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
// reverse proxy overwrites it. If this server can be reached directly, a
// caller can pick a fresh value per request and sidestep IP-based bans.
// Confirm the deployment always sits behind such a proxy.
func realIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := strings.TrimSpace(strings.SplitN(xff, ",", 2)[0])
		if net.ParseIP(first) != nil {
			return first
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No "host:port" shape to split; use the address as given.
		return r.RemoteAddr
	}
	return host
}
|