diff --git a/cmd/server/main.go b/cmd/server/main.go index ec71960..46e145f 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -150,10 +150,13 @@ func main() { // Profile / auth api.HandleFunc("/me", h.Auth.Me).Methods("GET") + api.HandleFunc("/profile", h.Auth.UpdateProfile).Methods("PUT") api.HandleFunc("/change-password", h.Auth.ChangePassword).Methods("POST") api.HandleFunc("/mfa/setup", h.Auth.MFASetupBegin).Methods("POST") api.HandleFunc("/mfa/confirm", h.Auth.MFASetupConfirm).Methods("POST") api.HandleFunc("/mfa/disable", h.Auth.MFADisable).Methods("POST") + api.HandleFunc("/ip-rules", h.Auth.GetUserIPRule).Methods("GET") + api.HandleFunc("/ip-rules", h.Auth.SetUserIPRule).Methods("PUT") // Providers (which OAuth providers are configured) api.HandleFunc("/providers", h.API.GetProviders).Methods("GET") diff --git a/internal/db/db.go b/internal/db/db.go index c76e05a..e5e0351 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -230,6 +230,21 @@ func (d *DB) Migrate() error { return fmt.Errorf("create ip_blocks: %w", err) } + // Per-user IP access rules. + // mode: "brute_skip" = skip brute force check for this user from listed IPs + // "allow_only" = only allow login from listed IPs (all others get 403) + if _, err := d.sql.Exec(`CREATE TABLE IF NOT EXISTS user_ip_rules ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + mode TEXT NOT NULL DEFAULT 'brute_skip', + ip_list TEXT NOT NULL DEFAULT '', + created_at DATETIME DEFAULT (datetime('now')), + updated_at DATETIME DEFAULT (datetime('now')), + UNIQUE(user_id) + )`); err != nil { + return fmt.Errorf("create user_ip_rules: %w", err) + } + // Bootstrap admin account if no users exist return d.bootstrapAdmin() } @@ -1873,3 +1888,103 @@ func (d *DB) LookupCachedCountry(ip string) (country, countryCode string) { ).Scan(&country, &countryCode) return } + +// ---- Profile Updates ---- + +// UpdateUserEmail changes a user's email address. Returns error if already taken. +func (d *DB) UpdateUserEmail(userID int64, newEmail string) error { + _, err := d.sql.Exec( + `UPDATE users SET email=?, updated_at=datetime('now') WHERE id=?`, + newEmail, userID) + return err +} + +// UpdateUserUsername changes a user's display username. Returns error if already taken. +func (d *DB) UpdateUserUsername(userID int64, newUsername string) error { + _, err := d.sql.Exec( + `UPDATE users SET username=?, updated_at=datetime('now') WHERE id=?`, + newUsername, userID) + return err +} + +// ---- Per-User IP Rules ---- + +// UserIPRule holds per-user IP access settings. +type UserIPRule struct { + UserID int64 `json:"user_id"` + Mode string `json:"mode"` // "brute_skip" | "allow_only" | "disabled" + IPList string `json:"ip_list"` // comma-separated IPs +} + +// GetUserIPRule returns the IP rule for a user, or nil if none set. +func (d *DB) GetUserIPRule(userID int64) (*UserIPRule, error) { + row := d.sql.QueryRow(`SELECT user_id, mode, ip_list FROM user_ip_rules WHERE user_id=?`, userID) + r := &UserIPRule{} + if err := row.Scan(&r.UserID, &r.Mode, &r.IPList); err == sql.ErrNoRows { + return nil, nil + } else if err != nil { + return nil, err + } + return r, nil +} + +// SetUserIPRule upserts the IP rule for a user. +func (d *DB) SetUserIPRule(userID int64, mode, ipList string) error { + _, err := d.sql.Exec(` + INSERT INTO user_ip_rules (user_id, mode, ip_list, updated_at) + VALUES (?, ?, ?, datetime('now')) + ON CONFLICT(user_id) DO UPDATE SET + mode=excluded.mode, + ip_list=excluded.ip_list, + updated_at=datetime('now')`, + userID, mode, ipList) + return err +} + +// DeleteUserIPRule removes IP rules for a user (disables the feature). +func (d *DB) DeleteUserIPRule(userID int64) error { + _, err := d.sql.Exec(`DELETE FROM user_ip_rules WHERE user_id=?`, userID) + return err +} + +// CheckUserIPAccess evaluates per-user IP rules against a connecting IP. +// Returns: +// "allow" — rule says allow (brute_skip match or allow_only match) +// "deny" — allow_only mode and IP is not in list +// "skip_brute" — brute_skip mode and IP is in list (skip brute force check) +// "default" — no rule exists, fall through to global rules +func (d *DB) CheckUserIPAccess(userID int64, ip string) string { + rule, err := d.GetUserIPRule(userID) + if err != nil || rule == nil || rule.Mode == "disabled" || rule.IPList == "" { + return "default" + } + for _, listed := range splitIPs(rule.IPList) { + if listed == ip { + if rule.Mode == "allow_only" { + return "allow" + } + return "skip_brute" + } + } + // IP not in list + if rule.Mode == "allow_only" { + return "deny" + } + return "default" +} + +// SplitIPList splits a comma-separated IP string into trimmed, non-empty entries. +func SplitIPList(s string) []string { + return splitIPs(s) +} + +func splitIPs(s string) []string { + var result []string + for _, p := range strings.Split(s, ",") { + p = strings.TrimSpace(p) + if p != "" { + result = append(result, p) + } + } + return result +} diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index 24baf3b..be16e39 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" + "net" "net/http" "time" @@ -53,6 +54,17 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { return } + // Per-user IP access check — evaluated before password to avoid timing leaks + switch h.db.CheckUserIPAccess(user.ID, ip) { + case "deny": + h.db.WriteAudit(&user.ID, models.AuditLoginFail, "IP not in allow-list: "+ip, ip, ua) + http.Redirect(w, r, "/auth/login?error=location_not_authorized", http.StatusFound) + return + case "skip_brute": + // Signal the BruteForceProtect middleware to skip failure counting for this user/IP + w.Header().Set("X-Skip-Brute", "1") + } + if err := crypto.CheckPassword(password, user.PasswordHash); err != nil { uid := user.ID h.db.WriteAudit(&uid, models.AuditLoginFail, "bad password for: "+username, ip, ua) @@ -403,3 +415,119 @@ func writeJSONError(w http.ResponseWriter, status int, msg string) { w.WriteHeader(status) json.NewEncoder(w).Encode(map[string]string{"error": msg}) } + +// ---- Profile Updates ---- + +func (h *AuthHandler) UpdateProfile(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r) + user, err := h.db.GetUserByID(userID) + if err != nil || user == nil { + writeJSONError(w, http.StatusUnauthorized, "not authenticated") + return + } + + var req struct { + Field string `json:"field"` // "email" | "username" + Value string `json:"value"` + Password string `json:"password"` // current password required for confirmation + } + json.NewDecoder(r.Body).Decode(&req) + + if req.Value == "" { + writeJSONError(w, http.StatusBadRequest, "value required") + return + } + if req.Password == "" { + writeJSONError(w, http.StatusBadRequest, "current password required to confirm profile changes") + return + } + if err := crypto.CheckPassword(req.Password, user.PasswordHash); err != nil { + writeJSONError(w, http.StatusForbidden, "incorrect password") + return + } + + switch req.Field { + case "email": + // Check uniqueness + existing, _ := h.db.GetUserByEmail(req.Value) + if existing != nil && existing.ID != userID { + writeJSONError(w, http.StatusConflict, "email already in use") + return + } + if err := h.db.UpdateUserEmail(userID, req.Value); err != nil { + writeJSONError(w, http.StatusInternalServerError, "failed to update email") + return + } + case "username": + existing, _ := h.db.GetUserByUsername(req.Value) + if existing != nil && existing.ID != userID { + writeJSONError(w, http.StatusConflict, "username already in use") + return + } + if err := h.db.UpdateUserUsername(userID, req.Value); err != nil { + writeJSONError(w, http.StatusInternalServerError, "failed to update username") + return + } + default: + writeJSONError(w, http.StatusBadRequest, "field must be 'email' or 'username'") + return + } + + ip := middleware.ClientIP(r) + h.db.WriteAudit(&userID, models.AuditUserUpdate, "profile update: "+req.Field, ip, r.UserAgent()) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]bool{"ok": true}) +} + +// ---- Per-User IP Rules ---- + +func (h *AuthHandler) GetUserIPRule(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r) + rule, err := h.db.GetUserIPRule(userID) + if err != nil { + writeJSONError(w, http.StatusInternalServerError, "db error") + return + } + if rule == nil { + rule = &db.UserIPRule{UserID: userID, Mode: "disabled", IPList: ""} + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(rule) +} + +func (h *AuthHandler) SetUserIPRule(w http.ResponseWriter, r *http.Request) { + userID := middleware.GetUserID(r) + var req struct { + Mode string `json:"mode"` // "disabled" | "brute_skip" | "allow_only" + IPList string `json:"ip_list"` // comma-separated + } + json.NewDecoder(r.Body).Decode(&req) + + validModes := map[string]bool{"disabled": true, "brute_skip": true, "allow_only": true} + if !validModes[req.Mode] { + writeJSONError(w, http.StatusBadRequest, "mode must be disabled, brute_skip, or allow_only") + return + } + + // Validate IPs + for _, rawIP := range db.SplitIPList(req.IPList) { + if net.ParseIP(rawIP) == nil { + writeJSONError(w, http.StatusBadRequest, "invalid IP address: "+rawIP) + return + } + } + + if req.Mode == "disabled" { + h.db.DeleteUserIPRule(userID) + } else { + if err := h.db.SetUserIPRule(userID, req.Mode, req.IPList); err != nil { + writeJSONError(w, http.StatusInternalServerError, "failed to save rule") + return + } + } + + ip := middleware.ClientIP(r) + h.db.WriteAudit(&userID, models.AuditUserUpdate, "IP rule updated: "+req.Mode, ip, r.UserAgent()) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]bool{"ok": true}) +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 2223e5d..388a4e8 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -227,7 +227,7 @@ func BruteForceProtect(database *db.DB, cfg *config.Config, next http.Handler) h username := r.FormValue("username") database.RecordLoginAttempt(ip, username, geoResult.Country, geoResult.CountryCode, success) - if !success { + if !success && !rw.skipBrute { failures := database.CountRecentFailures(ip, cfg.BruteWindowMins) if failures >= cfg.BruteMaxAttempts { reason := "Too many failed logins" @@ -260,16 +260,21 @@ func BruteForceProtect(database *db.DB, cfg *config.Config, next http.Handler) h }) } -// loginResponseCapture captures the redirect location from the login handler. +// loginResponseCapture captures the redirect location and skip-brute signal from the login handler. type loginResponseCapture struct { http.ResponseWriter statusCode int location string + skipBrute bool } func (lrc *loginResponseCapture) WriteHeader(code int) { lrc.statusCode = code lrc.location = lrc.ResponseWriter.Header().Get("Location") + if lrc.Header().Get("X-Skip-Brute") == "1" { + lrc.skipBrute = true + lrc.Header().Del("X-Skip-Brute") // strip before sending to client + } lrc.ResponseWriter.WriteHeader(code) } diff --git a/web/static/js/app.js b/web/static/js/app.js index 856ded4..c900f56 100644 --- a/web/static/js/app.js +++ b/web/static/js/app.js @@ -1379,6 +1379,28 @@ async function openSettings() { openModal('settings-modal'); loadSyncInterval(); renderMFAPanel(); + loadIPRules(); + // Pre-fill profile fields with current values + const me = await api('GET', '/me'); + if (me) { + document.getElementById('profile-username').placeholder = me.username || 'New username'; + document.getElementById('profile-email').placeholder = me.email || 'New email'; + } +} + +async function updateProfile(field) { + const value = document.getElementById('profile-' + field).value.trim(); + const password = document.getElementById('profile-confirm-pw').value; + if (!value) { toast('Please enter a new ' + field, 'error'); return; } + if (!password) { toast('Current password required to confirm changes', 'error'); return; } + const r = await api('PUT', '/profile', { field, value, password }); + if (r?.ok) { + toast(field.charAt(0).toUpperCase() + field.slice(1) + ' updated', 'success'); + document.getElementById('profile-' + field).value = ''; + document.getElementById('profile-confirm-pw').value = ''; + } else { + toast(r?.error || 'Update failed', 'error'); + } } async function loadSyncInterval() { @@ -1432,6 +1454,39 @@ async function disableMFA() { if(r?.ok){toast('MFA disabled','success');renderMFAPanel();}else toast(r?.error||'Invalid code','error'); } +async function loadIPRules() { + const r = await api('GET', '/ip-rules'); + if (!r) return; + document.getElementById('ip-rule-mode').value = r.mode || 'disabled'; + document.getElementById('ip-rule-list').value = r.ip_list || ''; + toggleIPRuleHelp(); +} + +function toggleIPRuleHelp() { + const mode = document.getElementById('ip-rule-mode').value; + const helpEl = document.getElementById('ip-rule-help'); + const listField = document.getElementById('ip-rule-list-field'); + const helps = { + disabled: '', + brute_skip: 'IPs in the list below will never be locked out of your account, even after many failed attempts. All other IPs are subject to global brute-force protection.', + allow_only: '⚠ Only IPs in the list below will be able to log into your account. All other IPs will see an "Access not authorized" error. Make sure to include your current IP before saving.', + }; + helpEl.textContent = helps[mode] || ''; + helpEl.style.display = mode !== 'disabled' ? 'block' : 'none'; + listField.style.display = mode !== 'disabled' ? 'block' : 'none'; +} + +async function saveIPRules() { + const mode = document.getElementById('ip-rule-mode').value; + const ip_list = document.getElementById('ip-rule-list').value.trim(); + if (mode !== 'disabled' && !ip_list) { + toast('Please enter at least one IP address', 'error'); return; + } + const r = await api('PUT', '/ip-rules', { mode, ip_list }); + if (r?.ok) toast('IP rules saved', 'success'); + else toast(r?.error || 'Save failed', 'error'); +} + async function doLogout() { await fetch('/auth/logout',{method:'POST'}); location.href='/auth/login'; } // ── Context menu helper ──────────────────────────────────────────────────── diff --git a/web/templates/admin.html b/web/templates/admin.html index c39dfb4..d53da30 100644 --- a/web/templates/admin.html +++ b/web/templates/admin.html @@ -39,5 +39,5 @@ {{end}} {{define "scripts"}} - + {{end}} \ No newline at end of file diff --git a/web/templates/app.html b/web/templates/app.html index b17cc99..9d507eb 100644 --- a/web/templates/app.html +++ b/web/templates/app.html @@ -264,12 +264,34 @@ @@ -309,5 +352,5 @@ {{end}} {{define "scripts"}} - + {{end}} diff --git a/web/templates/base.html b/web/templates/base.html index cbb9366..0ed7e90 100644 --- a/web/templates/base.html +++ b/web/templates/base.html @@ -5,12 +5,12 @@ {{block "title" .}}GoWebMail{{end}} - + {{block "head_extra" .}}{{end}} {{block "body" .}}{{end}} - + {{block "scripts" .}}{{end}} diff --git a/web/templates/login.html b/web/templates/login.html index dbecc14..3ddff5b 100644 --- a/web/templates/login.html +++ b/web/templates/login.html @@ -19,7 +19,7 @@ {{end}} {{define "scripts"}}