build-base
This commit is contained in:
@@ -0,0 +1,730 @@
|
||||
// Package caldav implements a CalDAV server (RFC 4791 over WebDAV RFC 4918).
|
||||
// Authentication: HTTP Basic Auth against the user DB.
|
||||
// Events are stored as AES-256-GCM encrypted iCalendar blobs.
|
||||
package caldav
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
nsDAV = "DAV:"
|
||||
nsCalDAV = "urn:ietf:params:xml:ns:caldav"
|
||||
nsCS = "http://calendarserver.org/ns/"
|
||||
)
|
||||
|
||||
// Deps holds CalDAV server dependencies.
|
||||
type Deps struct {
|
||||
DB *db.DB
|
||||
Crypt *appCrypto.Crypto
|
||||
}
|
||||
|
||||
// Server is the CalDAV HTTP handler.
|
||||
type Server struct {
|
||||
deps *Deps
|
||||
mux *http.ServeMux
|
||||
}
|
||||
|
||||
// New creates a CalDAV server and registers handlers on the given mux prefix.
|
||||
func New(deps *Deps) *Server {
|
||||
s := &Server{deps: deps, mux: http.NewServeMux()}
|
||||
s.setup()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.mux.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (s *Server) setup() {
|
||||
// Well-known discovery redirects.
|
||||
s.mux.HandleFunc("/.well-known/caldav", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/caldav/", http.StatusMovedPermanently)
|
||||
})
|
||||
|
||||
// All CalDAV routes use a catch-all; we route internally by path/method.
|
||||
s.mux.HandleFunc("/caldav/", s.withAuth(s.route))
|
||||
}
|
||||
|
||||
// withAuth wraps a handler requiring HTTP Basic Auth.
|
||||
func (s *Server) withAuth(next func(http.ResponseWriter, *http.Request, *models.User)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := s.authenticate(r)
|
||||
if err != nil || user == nil {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="mailgosend CalDAV", charset="UTF-8"`)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if !user.Enabled {
|
||||
http.Error(w, "Account disabled", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next(w, r, user)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) authenticate(r *http.Request) (*models.User, error) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authHeader, "Basic ") {
|
||||
return nil, nil
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(authHeader[6:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("basic auth decode: %w", err)
|
||||
}
|
||||
parts := strings.SplitN(string(decoded), ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("basic auth format")
|
||||
}
|
||||
email := strings.TrimSpace(parts[0])
|
||||
password := parts[1]
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
user, err := s.deps.DB.GetUserByEmail(ctx, email)
|
||||
if err != nil || user == nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := appCrypto.CheckPassword(user.PasswordHash, password); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// route dispatches CalDAV requests by path structure and HTTP method.
|
||||
// Path patterns under /caldav/:
|
||||
// /caldav/ → root (discovery)
|
||||
// /caldav/p/{userID} → principal
|
||||
// /caldav/{userID}/ → calendar home
|
||||
// /caldav/{userID}/{calID}/ → calendar collection
|
||||
// /caldav/{userID}/{calID}/{uid}.ics → event resource
|
||||
func (s *Server) route(w http.ResponseWriter, r *http.Request, user *models.User) {
|
||||
path := strings.TrimPrefix(r.URL.Path, "/caldav")
|
||||
path = strings.TrimSuffix(path, "/")
|
||||
segments := splitPath(path)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
switch len(segments) {
|
||||
case 0:
|
||||
// /caldav/ — root
|
||||
s.handleOptions(w, r, "caldav")
|
||||
if r.Method == "PROPFIND" {
|
||||
s.propfindRoot(w, r, user)
|
||||
}
|
||||
|
||||
case 1:
|
||||
if segments[0] == "p" || strings.HasPrefix(segments[0], "p") {
|
||||
// /caldav/p — principal without user ID (redirect to user principal)
|
||||
http.Redirect(w, r, fmt.Sprintf("/caldav/p/%d", user.ID), http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
// /caldav/{userID}/ — calendar home
|
||||
ownerID, err := strconv.ParseInt(segments[0], 10, 64)
|
||||
if err != nil || ownerID != user.ID {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
switch r.Method {
|
||||
case "OPTIONS":
|
||||
s.handleOptions(w, r, "collection")
|
||||
case "PROPFIND":
|
||||
s.propfindHome(w, r, user)
|
||||
case "MKCOL":
|
||||
s.mkcalendarHome(w, r, user)
|
||||
default:
|
||||
w.Header().Set("Allow", davAllow)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
case 2:
|
||||
// /caldav/p/{userID} — principal
|
||||
if segments[0] == "p" {
|
||||
switch r.Method {
|
||||
case "OPTIONS":
|
||||
s.handleOptions(w, r, "principal")
|
||||
case "PROPFIND":
|
||||
s.propfindPrincipal(w, r, user)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
return
|
||||
}
|
||||
// /caldav/{userID}/{calID}/ — calendar collection
|
||||
ownerID, _ := strconv.ParseInt(segments[0], 10, 64)
|
||||
if ownerID != user.ID {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
calID, err := strconv.ParseInt(segments[1], 10, 64)
|
||||
if err != nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
cal, err := s.deps.DB.GetCalendarByID(r.Context(), calID)
|
||||
if err != nil || cal == nil || cal.UserID != user.ID {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
switch r.Method {
|
||||
case "OPTIONS":
|
||||
s.handleOptions(w, r, "calendar")
|
||||
case "PROPFIND":
|
||||
s.propfindCalendar(w, r, user, cal)
|
||||
case "REPORT":
|
||||
s.reportCalendar(w, r, user, cal)
|
||||
case "DELETE":
|
||||
s.deps.DB.DeleteCalendar(r.Context(), calID) //nolint:errcheck
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
w.Header().Set("Allow", davAllow)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
case 3:
|
||||
// /caldav/{userID}/{calID}/{uid}.ics — event resource
|
||||
ownerID, _ := strconv.ParseInt(segments[0], 10, 64)
|
||||
if ownerID != user.ID {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
calID, err := strconv.ParseInt(segments[1], 10, 64)
|
||||
if err != nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
cal, err := s.deps.DB.GetCalendarByID(r.Context(), calID)
|
||||
if err != nil || cal == nil || cal.UserID != user.ID {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
uid := strings.TrimSuffix(segments[2], ".ics")
|
||||
switch r.Method {
|
||||
case "GET", "HEAD":
|
||||
s.getEvent(w, r, user, cal, uid)
|
||||
case "PUT":
|
||||
s.putEvent(w, r, user, cal, uid)
|
||||
case "DELETE":
|
||||
s.deleteEvent(w, r, user, cal, uid)
|
||||
case "OPTIONS":
|
||||
s.handleOptions(w, r, "event")
|
||||
default:
|
||||
w.Header().Set("Allow", davAllow)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
const davAllow = "OPTIONS, GET, PUT, DELETE, PROPFIND, MKCOL, REPORT"
|
||||
|
||||
func (s *Server) handleOptions(w http.ResponseWriter, r *http.Request, resType string) {
|
||||
w.Header().Set("DAV", "1, 2, 3, calendar-access")
|
||||
w.Header().Set("Allow", davAllow)
|
||||
w.Header().Set("Ms-Author-Via", "DAV")
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- PROPFIND handlers ----
|
||||
|
||||
func (s *Server) propfindRoot(w http.ResponseWriter, r *http.Request, user *models.User) {
|
||||
depth := r.Header.Get("Depth")
|
||||
if depth == "" {
|
||||
depth = "0"
|
||||
}
|
||||
|
||||
responses := []davResponse{
|
||||
{
|
||||
Href: "/caldav/",
|
||||
Props: []davProp{
|
||||
{Name: "resourcetype", NS: nsDAV, Value: "<D:collection/>"},
|
||||
{Name: "displayname", NS: nsDAV, Value: "CalDAV"},
|
||||
{Name: "current-user-principal", NS: nsDAV,
|
||||
Value: fmt.Sprintf("<D:href>/caldav/p/%d</D:href>", user.ID)},
|
||||
},
|
||||
},
|
||||
}
|
||||
writeMultiStatus(w, responses)
|
||||
}
|
||||
|
||||
func (s *Server) propfindPrincipal(w http.ResponseWriter, r *http.Request, user *models.User) {
|
||||
responses := []davResponse{
|
||||
{
|
||||
Href: fmt.Sprintf("/caldav/p/%d", user.ID),
|
||||
Props: []davProp{
|
||||
{Name: "resourcetype", NS: nsDAV, Value: "<D:principal/>"},
|
||||
{Name: "displayname", NS: nsDAV, Value: xmlEscape(user.DisplayName)},
|
||||
{Name: "current-user-principal", NS: nsDAV,
|
||||
Value: fmt.Sprintf("<D:href>/caldav/p/%d</D:href>", user.ID)},
|
||||
{Name: "calendar-home-set", NS: nsCalDAV,
|
||||
Value: fmt.Sprintf("<D:href>/caldav/%d/</D:href>", user.ID)},
|
||||
},
|
||||
},
|
||||
}
|
||||
writeMultiStatus(w, responses)
|
||||
}
|
||||
|
||||
func (s *Server) propfindHome(w http.ResponseWriter, r *http.Request, user *models.User) {
|
||||
depth := r.Header.Get("Depth")
|
||||
|
||||
cals, err := s.deps.DB.ListCalendars(r.Context(), user.ID)
|
||||
if err != nil {
|
||||
log.Printf("[caldav] list calendars: %v", err)
|
||||
}
|
||||
|
||||
responses := []davResponse{
|
||||
{
|
||||
Href: fmt.Sprintf("/caldav/%d/", user.ID),
|
||||
Props: []davProp{
|
||||
{Name: "resourcetype", NS: nsDAV, Value: "<D:collection/>"},
|
||||
{Name: "displayname", NS: nsDAV, Value: "Calendars"},
|
||||
{Name: "current-user-principal", NS: nsDAV,
|
||||
Value: fmt.Sprintf("<D:href>/caldav/p/%d</D:href>", user.ID)},
|
||||
{Name: "calendar-home-set", NS: nsCalDAV,
|
||||
Value: fmt.Sprintf("<D:href>/caldav/%d/</D:href>", user.ID)},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if depth != "0" {
|
||||
for _, cal := range cals {
|
||||
responses = append(responses, calendarResponse(user.ID, cal))
|
||||
}
|
||||
}
|
||||
writeMultiStatus(w, responses)
|
||||
}
|
||||
|
||||
func (s *Server) propfindCalendar(w http.ResponseWriter, r *http.Request, user *models.User, cal *models.Calendar) {
|
||||
depth := r.Header.Get("Depth")
|
||||
|
||||
responses := []davResponse{calendarResponse(user.ID, cal)}
|
||||
|
||||
if depth != "0" {
|
||||
events, err := s.deps.DB.ListCalendarEvents(r.Context(), cal.ID)
|
||||
if err != nil {
|
||||
log.Printf("[caldav] list events: %v", err)
|
||||
}
|
||||
for _, ev := range events {
|
||||
responses = append(responses, eventResponse(user.ID, cal.ID, ev))
|
||||
}
|
||||
}
|
||||
writeMultiStatus(w, responses)
|
||||
}
|
||||
|
||||
func calendarResponse(userID int64, cal *models.Calendar) davResponse {
|
||||
ctag := strconv.FormatInt(cal.SyncToken, 10)
|
||||
syncToken := fmt.Sprintf("https://example.com/ns/sync/%d", cal.SyncToken)
|
||||
return davResponse{
|
||||
Href: fmt.Sprintf("/caldav/%d/%d/", userID, cal.ID),
|
||||
Props: []davProp{
|
||||
{Name: "resourcetype", NS: nsDAV,
|
||||
Value: `<D:collection/><C:calendar xmlns:C="` + nsCalDAV + `"/>`},
|
||||
{Name: "displayname", NS: nsDAV, Value: xmlEscape(cal.Name)},
|
||||
{Name: "calendar-description", NS: nsCalDAV, Value: xmlEscape(cal.Description)},
|
||||
{Name: "calendar-color", NS: "http://apple.com/ns/ical/", Value: xmlEscape(cal.Color)},
|
||||
{Name: "supported-calendar-component-set", NS: nsCalDAV,
|
||||
Value: `<C:comp xmlns:C="` + nsCalDAV + `" name="VEVENT"/><C:comp xmlns:C="` + nsCalDAV + `" name="VTODO"/>`},
|
||||
{Name: "getctag", NS: nsCS, Value: ctag},
|
||||
{Name: "sync-token", NS: nsDAV, Value: xmlEscape(syncToken)},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func eventResponse(userID, calID int64, ev *models.CalendarEvent) davResponse {
|
||||
return davResponse{
|
||||
Href: fmt.Sprintf("/caldav/%d/%d/%s.ics", userID, calID, ev.UID),
|
||||
Props: []davProp{
|
||||
{Name: "resourcetype", NS: nsDAV, Value: ""},
|
||||
{Name: "getetag", NS: nsDAV, Value: `"` + ev.ETag + `"`},
|
||||
{Name: "getcontenttype", NS: nsDAV, Value: "text/calendar; charset=utf-8"},
|
||||
{Name: "getlastmodified", NS: nsDAV, Value: ev.UpdatedAt.UTC().Format(http.TimeFormat)},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ---- REPORT ----
|
||||
|
||||
func (s *Server) reportCalendar(w http.ResponseWriter, r *http.Request, user *models.User, cal *models.Calendar) {
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 256*1024))
|
||||
if err != nil {
|
||||
http.Error(w, "read error", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
XMLName xml.Name `xml:""`
|
||||
SyncToken string `xml:"sync-token"`
|
||||
Hrefs []string `xml:"href"`
|
||||
}
|
||||
_ = xml.Unmarshal(body, &req)
|
||||
|
||||
localName := ""
|
||||
if req.XMLName.Local != "" {
|
||||
localName = req.XMLName.Local
|
||||
}
|
||||
|
||||
switch localName {
|
||||
case "sync-collection":
|
||||
s.syncCollection(w, r, user, cal, req.SyncToken)
|
||||
case "calendar-multiget":
|
||||
s.calendarMultiget(w, r, user, cal, req.Hrefs)
|
||||
default:
|
||||
// calendar-query or unknown: return all events.
|
||||
s.propfindCalendar(w, r, user, cal)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) syncCollection(w http.ResponseWriter, r *http.Request, user *models.User, cal *models.Calendar, clientToken string) {
|
||||
// Minimal sync: if token matches current, return 0 changes; else return all.
|
||||
curToken := fmt.Sprintf("https://example.com/ns/sync/%d", cal.SyncToken)
|
||||
if clientToken == curToken {
|
||||
// No changes since last sync.
|
||||
writeMultiStatus(w, []davResponse{})
|
||||
return
|
||||
}
|
||||
// Full sync: return all events.
|
||||
events, err := s.deps.DB.ListCalendarEvents(r.Context(), cal.ID)
|
||||
if err != nil {
|
||||
log.Printf("[caldav] sync events: %v", err)
|
||||
}
|
||||
var responses []davResponse
|
||||
for _, ev := range events {
|
||||
responses = append(responses, eventResponse(user.ID, cal.ID, ev))
|
||||
}
|
||||
writeMultiStatus(w, responses)
|
||||
}
|
||||
|
||||
func (s *Server) calendarMultiget(w http.ResponseWriter, r *http.Request, user *models.User, cal *models.Calendar, hrefs []string) {
|
||||
var responses []davResponse
|
||||
for _, href := range hrefs {
|
||||
// Extract UID from href like /caldav/{uid}/{calid}/{uid}.ics
|
||||
parts := splitPath(strings.TrimPrefix(href, "/caldav"))
|
||||
if len(parts) < 3 {
|
||||
continue
|
||||
}
|
||||
uid := strings.TrimSuffix(parts[len(parts)-1], ".ics")
|
||||
ev, err := s.deps.DB.GetCalendarEvent(r.Context(), cal.ID, uid)
|
||||
if err != nil || ev == nil {
|
||||
responses = append(responses, davResponse{
|
||||
Href: href,
|
||||
Status: "HTTP/1.1 404 Not Found",
|
||||
})
|
||||
continue
|
||||
}
|
||||
// Return event with ical data.
|
||||
raw, err := s.decryptICal(user.ID, ev.ICalEnc)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res := eventResponse(user.ID, cal.ID, ev)
|
||||
res.Props = append(res.Props, davProp{
|
||||
Name: "calendar-data", NS: nsCalDAV, Value: xmlEscape(string(raw)), CData: true,
|
||||
})
|
||||
responses = append(responses, res)
|
||||
}
|
||||
writeMultiStatus(w, responses)
|
||||
}
|
||||
|
||||
// ---- Event GET/PUT/DELETE ----
|
||||
|
||||
func (s *Server) getEvent(w http.ResponseWriter, r *http.Request, user *models.User, cal *models.Calendar, uid string) {
|
||||
ev, err := s.deps.DB.GetCalendarEvent(r.Context(), cal.ID, uid)
|
||||
if err != nil || ev == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Conditional GET.
|
||||
if match := r.Header.Get("If-None-Match"); match != "" {
|
||||
if match == `"`+ev.ETag+`"` {
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
raw, err := s.decryptICal(user.ID, ev.ICalEnc)
|
||||
if err != nil {
|
||||
http.Error(w, "decrypt error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/calendar; charset=utf-8")
|
||||
w.Header().Set("ETag", `"`+ev.ETag+`"`)
|
||||
w.Header().Set("Last-Modified", ev.UpdatedAt.UTC().Format(http.TimeFormat))
|
||||
if r.Method == "HEAD" {
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(raw)))
|
||||
return
|
||||
}
|
||||
w.Write(raw) //nolint:errcheck
|
||||
}
|
||||
|
||||
func (s *Server) putEvent(w http.ResponseWriter, r *http.Request, user *models.User, cal *models.Calendar, uid string) {
|
||||
if r.ContentLength > 1024*1024 {
|
||||
http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
|
||||
raw, err := io.ReadAll(io.LimitReader(r.Body, 1024*1024))
|
||||
if err != nil || len(raw) == 0 {
|
||||
http.Error(w, "read error", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate it's iCalendar data.
|
||||
rawStr := string(raw)
|
||||
if !strings.Contains(rawStr, "BEGIN:VCALENDAR") {
|
||||
http.Error(w, "Invalid iCalendar data", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Conditional check: If-Match must match existing ETag.
|
||||
existing, err := s.deps.DB.GetCalendarEvent(r.Context(), cal.ID, uid)
|
||||
if err != nil {
|
||||
http.Error(w, "db error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if ifMatch := r.Header.Get("If-Match"); ifMatch != "" && ifMatch != "*" {
|
||||
if existing == nil || `"`+existing.ETag+`"` != ifMatch {
|
||||
http.Error(w, "Precondition Failed", http.StatusPreconditionFailed)
|
||||
return
|
||||
}
|
||||
}
|
||||
if r.Header.Get("If-None-Match") == "*" && existing != nil {
|
||||
http.Error(w, "Precondition Failed", http.StatusPreconditionFailed)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse minimal fields from iCalendar for DB index.
|
||||
dtStart, dtEnd, summary, recurring := parseICal(rawStr)
|
||||
|
||||
// Encrypt and store.
|
||||
icalEnc, err := s.encryptICal(user.ID, raw)
|
||||
if err != nil {
|
||||
http.Error(w, "encrypt error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
etag := sha256Hex(raw)
|
||||
if err := s.deps.DB.UpsertCalendarEvent(r.Context(), cal.ID, uid, etag, icalEnc, dtStart, dtEnd, summary, recurring); err != nil {
|
||||
log.Printf("[caldav] upsert event: %v", err)
|
||||
http.Error(w, "db error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if _, err := s.deps.DB.BumpCalendarSyncToken(r.Context(), cal.ID); err != nil {
|
||||
log.Printf("[caldav] bump token: %v", err)
|
||||
}
|
||||
|
||||
w.Header().Set("ETag", `"`+etag+`"`)
|
||||
if existing == nil {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) deleteEvent(w http.ResponseWriter, r *http.Request, user *models.User, cal *models.Calendar, uid string) {
|
||||
ev, err := s.deps.DB.GetCalendarEvent(r.Context(), cal.ID, uid)
|
||||
if err != nil || ev == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if ifMatch := r.Header.Get("If-Match"); ifMatch != "" && ifMatch != "*" {
|
||||
if `"`+ev.ETag+`"` != ifMatch {
|
||||
http.Error(w, "Precondition Failed", http.StatusPreconditionFailed)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := s.deps.DB.DeleteCalendarEvent(r.Context(), cal.ID, uid); err != nil {
|
||||
http.Error(w, "db error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if _, err := s.deps.DB.BumpCalendarSyncToken(r.Context(), cal.ID); err != nil {
|
||||
log.Printf("[caldav] bump token delete: %v", err)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (s *Server) mkcalendarHome(w http.ResponseWriter, r *http.Request, user *models.User) {
|
||||
// MKCOL on calendar home: ensure default calendar exists.
|
||||
if _, err := s.deps.DB.EnsureDefaultCalendar(r.Context(), user.ID); err != nil {
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
|
||||
// ---- Encryption helpers ----
|
||||
|
||||
func (s *Server) encryptICal(userID int64, plain []byte) ([]byte, error) {
|
||||
key, err := s.deps.Crypt.DeriveKey("ical", userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return appCrypto.Encrypt(key, plain)
|
||||
}
|
||||
|
||||
func (s *Server) decryptICal(userID int64, enc []byte) ([]byte, error) {
|
||||
key, err := s.deps.Crypt.DeriveKey("ical", userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return appCrypto.Decrypt(key, enc)
|
||||
}
|
||||
|
||||
// ---- iCalendar parsing (minimal, no third-party lib) ----
|
||||
|
||||
// parseICal extracts DTSTART, DTEND, SUMMARY, RRULE presence from raw iCal text.
|
||||
func parseICal(raw string) (dtStart, dtEnd time.Time, summary string, recurring bool) {
|
||||
for _, line := range strings.Split(raw, "\n") {
|
||||
line = strings.TrimRight(line, "\r")
|
||||
k, v, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// Strip parameters from key (e.g., DTSTART;TZID=UTC → DTSTART)
|
||||
k = strings.SplitN(k, ";", 2)[0]
|
||||
switch strings.ToUpper(k) {
|
||||
case "DTSTART":
|
||||
dtStart = parseICalTime(v)
|
||||
case "DTEND":
|
||||
dtEnd = parseICalTime(v)
|
||||
case "SUMMARY":
|
||||
summary = icalUnfold(v)
|
||||
case "RRULE":
|
||||
recurring = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func parseICalTime(s string) time.Time {
|
||||
s = strings.TrimSuffix(s, "Z")
|
||||
formats := []string{"20060102T150405", "20060102"}
|
||||
for _, f := range formats {
|
||||
if t, err := time.Parse(f, s); err == nil {
|
||||
return t.UTC()
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func icalUnfold(s string) string {
|
||||
return strings.ReplaceAll(s, `\n`, "\n")
|
||||
}
|
||||
|
||||
// ---- XML multi-status response ----
|
||||
|
||||
type davProp struct {
|
||||
Name string
|
||||
NS string
|
||||
Value string
|
||||
CData bool // wrap Value in CDATA
|
||||
}
|
||||
|
||||
type davResponse struct {
|
||||
Href string
|
||||
Status string // if empty, use 200 OK
|
||||
Props []davProp
|
||||
}
|
||||
|
||||
func writeMultiStatus(w http.ResponseWriter, responses []davResponse) {
|
||||
w.Header().Set("Content-Type", "application/xml; charset=utf-8")
|
||||
w.WriteHeader(http.StatusMultiStatus)
|
||||
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?>`+"\n")
|
||||
fmt.Fprintf(w, `<D:multistatus xmlns:D="DAV:" xmlns:C="urn:ietf:params:xml:ns:caldav" xmlns:CS="http://calendarserver.org/ns/" xmlns:ICAL="http://apple.com/ns/ical/">`+"\n")
|
||||
|
||||
for _, resp := range responses {
|
||||
fmt.Fprintf(w, " <D:response>\n")
|
||||
fmt.Fprintf(w, " <D:href>%s</D:href>\n", xmlEscape(resp.Href))
|
||||
|
||||
if resp.Status != "" {
|
||||
fmt.Fprintf(w, " <D:status>%s</D:status>\n", xmlEscape(resp.Status))
|
||||
} else if len(resp.Props) > 0 {
|
||||
fmt.Fprintf(w, " <D:propstat>\n <D:prop>\n")
|
||||
for _, p := range resp.Props {
|
||||
writeDAVProp(w, p)
|
||||
}
|
||||
fmt.Fprintf(w, " </D:prop>\n <D:status>HTTP/1.1 200 OK</D:status>\n </D:propstat>\n")
|
||||
} else {
|
||||
fmt.Fprintf(w, " <D:status>HTTP/1.1 200 OK</D:status>\n")
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, " </D:response>\n")
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "</D:multistatus>\n")
|
||||
}
|
||||
|
||||
func writeDAVProp(w http.ResponseWriter, p davProp) {
|
||||
ns := ""
|
||||
switch p.NS {
|
||||
case nsDAV:
|
||||
ns = "D"
|
||||
case nsCalDAV:
|
||||
ns = "C"
|
||||
case nsCS:
|
||||
ns = "CS"
|
||||
case "http://apple.com/ns/ical/":
|
||||
ns = "ICAL"
|
||||
default:
|
||||
ns = "D"
|
||||
}
|
||||
|
||||
if p.Value == "" {
|
||||
fmt.Fprintf(w, " <%s:%s/>\n", ns, p.Name)
|
||||
return
|
||||
}
|
||||
if p.CData {
|
||||
fmt.Fprintf(w, " <%s:%s><![CDATA[%s]]></%s:%s>\n", ns, p.Name, p.Value, ns, p.Name)
|
||||
} else {
|
||||
fmt.Fprintf(w, " <%s:%s>%s</%s:%s>\n", ns, p.Name, p.Value, ns, p.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- helpers ----
|
||||
|
||||
func splitPath(path string) []string {
|
||||
var out []string
|
||||
for _, s := range strings.Split(strings.Trim(path, "/"), "/") {
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func xmlEscape(s string) string {
|
||||
s = strings.ReplaceAll(s, "&", "&")
|
||||
s = strings.ReplaceAll(s, "<", "<")
|
||||
s = strings.ReplaceAll(s, ">", ">")
|
||||
s = strings.ReplaceAll(s, `"`, """)
|
||||
return s
|
||||
}
|
||||
|
||||
func sha256Hex(data []byte) string {
|
||||
h := sha256.Sum256(data)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
@@ -0,0 +1,614 @@
|
||||
// Package carddav implements a CardDAV server (RFC 6352 over WebDAV RFC 4918).
|
||||
// Authentication: HTTP Basic Auth against the user DB.
|
||||
// Contacts are stored as AES-256-GCM encrypted vCard blobs.
|
||||
package carddav
|
||||
|
||||
import (
|
||||
"context"
|
||||
gocrypto "crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
nsDAV = "DAV:"
|
||||
nsCardDAV = "urn:ietf:params:xml:ns:carddav"
|
||||
nsCS = "http://calendarserver.org/ns/"
|
||||
)
|
||||
|
||||
// Deps holds CardDAV server dependencies.
|
||||
type Deps struct {
|
||||
DB *db.DB
|
||||
Crypt *appCrypto.Crypto
|
||||
}
|
||||
|
||||
// Server is the CardDAV HTTP handler.
|
||||
type Server struct {
|
||||
deps *Deps
|
||||
mux *http.ServeMux
|
||||
}
|
||||
|
||||
// New creates a CardDAV server.
|
||||
func New(deps *Deps) *Server {
|
||||
s := &Server{deps: deps, mux: http.NewServeMux()}
|
||||
s.setup()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.mux.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (s *Server) setup() {
|
||||
s.mux.HandleFunc("/.well-known/carddav", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/carddav/", http.StatusMovedPermanently)
|
||||
})
|
||||
s.mux.HandleFunc("/carddav/", s.withAuth(s.route))
|
||||
}
|
||||
|
||||
func (s *Server) withAuth(next func(http.ResponseWriter, *http.Request, *models.User)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := s.authenticate(r)
|
||||
if err != nil || user == nil {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="mailgosend CardDAV", charset="UTF-8"`)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if !user.Enabled {
|
||||
http.Error(w, "Account disabled", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next(w, r, user)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) authenticate(r *http.Request) (*models.User, error) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(authHeader, "Basic ") {
|
||||
return nil, nil
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(authHeader[6:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("basic auth decode: %w", err)
|
||||
}
|
||||
parts := strings.SplitN(string(decoded), ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("basic auth format")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
user, err := s.deps.DB.GetUserByEmail(ctx, strings.TrimSpace(parts[0]))
|
||||
if err != nil || user == nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := appCrypto.CheckPassword(user.PasswordHash, parts[1]); err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// route dispatches CardDAV requests.
|
||||
// Paths under /carddav/:
|
||||
// /carddav/ → root
|
||||
// /carddav/p/{userID} → principal
|
||||
// /carddav/{userID}/ → address book home
|
||||
// /carddav/{userID}/{abID}/ → address book collection
|
||||
// /carddav/{userID}/{abID}/{uid}.vcf → contact resource
|
||||
func (s *Server) route(w http.ResponseWriter, r *http.Request, user *models.User) {
|
||||
path := strings.TrimPrefix(r.URL.Path, "/carddav")
|
||||
path = strings.TrimSuffix(path, "/")
|
||||
segments := splitPath(path)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
const allow = "OPTIONS, GET, PUT, DELETE, PROPFIND, MKCOL, REPORT"
|
||||
|
||||
switch len(segments) {
|
||||
case 0:
|
||||
s.setDavHeaders(w)
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
if r.Method == "PROPFIND" {
|
||||
s.propfindRoot(w, r, user)
|
||||
return
|
||||
}
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
|
||||
case 1:
|
||||
if segments[0] == "p" {
|
||||
http.Redirect(w, r, fmt.Sprintf("/carddav/p/%d", user.ID), http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
ownerID, err := strconv.ParseInt(segments[0], 10, 64)
|
||||
if err != nil || ownerID != user.ID {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
s.setDavHeaders(w)
|
||||
switch r.Method {
|
||||
case "OPTIONS":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case "PROPFIND":
|
||||
s.propfindHome(w, r, user)
|
||||
default:
|
||||
w.Header().Set("Allow", allow)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
case 2:
|
||||
if segments[0] == "p" {
|
||||
s.setDavHeaders(w)
|
||||
if r.Method == "PROPFIND" {
|
||||
s.propfindPrincipal(w, r, user)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
ownerID, _ := strconv.ParseInt(segments[0], 10, 64)
|
||||
if ownerID != user.ID {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
abID, err := strconv.ParseInt(segments[1], 10, 64)
|
||||
if err != nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
ab, err := s.deps.DB.GetAddressBookByID(r.Context(), abID)
|
||||
if err != nil || ab == nil || ab.UserID != user.ID {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
s.setDavHeaders(w)
|
||||
switch r.Method {
|
||||
case "OPTIONS":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case "PROPFIND":
|
||||
s.propfindAddressBook(w, r, user, ab)
|
||||
case "REPORT":
|
||||
s.reportAddressBook(w, r, user, ab)
|
||||
case "DELETE":
|
||||
s.deps.DB.DeleteAddressBook(r.Context(), abID) //nolint:errcheck
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
w.Header().Set("Allow", allow)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
case 3:
|
||||
ownerID, _ := strconv.ParseInt(segments[0], 10, 64)
|
||||
if ownerID != user.ID {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
abID, err := strconv.ParseInt(segments[1], 10, 64)
|
||||
if err != nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
ab, err := s.deps.DB.GetAddressBookByID(r.Context(), abID)
|
||||
if err != nil || ab == nil || ab.UserID != user.ID {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
uid := strings.TrimSuffix(segments[2], ".vcf")
|
||||
switch r.Method {
|
||||
case "GET", "HEAD":
|
||||
s.getContact(w, r, user, ab, uid)
|
||||
case "PUT":
|
||||
s.putContact(w, r, user, ab, uid)
|
||||
case "DELETE":
|
||||
s.deleteContact(w, r, user, ab, uid)
|
||||
case "OPTIONS":
|
||||
s.setDavHeaders(w)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
w.Header().Set("Allow", allow)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setDavHeaders(w http.ResponseWriter) {
|
||||
w.Header().Set("DAV", "1, 2, 3, addressbook")
|
||||
w.Header().Set("Allow", "OPTIONS, GET, PUT, DELETE, PROPFIND, MKCOL, REPORT")
|
||||
w.Header().Set("Ms-Author-Via", "DAV")
|
||||
}
|
||||
|
||||
// ---- PROPFIND handlers ----
|
||||
|
||||
func (s *Server) propfindRoot(w http.ResponseWriter, r *http.Request, user *models.User) {
|
||||
writeMultiStatus(w, []davResponse{
|
||||
{
|
||||
Href: "/carddav/",
|
||||
Props: []davProp{
|
||||
{Name: "resourcetype", NS: nsDAV, Value: "<D:collection/>"},
|
||||
{Name: "displayname", NS: nsDAV, Value: "CardDAV"},
|
||||
{Name: "current-user-principal", NS: nsDAV,
|
||||
Value: fmt.Sprintf("<D:href>/carddav/p/%d</D:href>", user.ID)},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) propfindPrincipal(w http.ResponseWriter, r *http.Request, user *models.User) {
|
||||
writeMultiStatus(w, []davResponse{
|
||||
{
|
||||
Href: fmt.Sprintf("/carddav/p/%d", user.ID),
|
||||
Props: []davProp{
|
||||
{Name: "resourcetype", NS: nsDAV, Value: "<D:principal/>"},
|
||||
{Name: "displayname", NS: nsDAV, Value: xmlEscape(user.DisplayName)},
|
||||
{Name: "current-user-principal", NS: nsDAV,
|
||||
Value: fmt.Sprintf("<D:href>/carddav/p/%d</D:href>", user.ID)},
|
||||
{Name: "addressbook-home-set", NS: nsCardDAV,
|
||||
Value: fmt.Sprintf("<D:href>/carddav/%d/</D:href>", user.ID)},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) propfindHome(w http.ResponseWriter, r *http.Request, user *models.User) {
|
||||
depth := r.Header.Get("Depth")
|
||||
abs, err := s.deps.DB.ListAddressBooks(r.Context(), user.ID)
|
||||
if err != nil {
|
||||
log.Printf("[carddav] list address books: %v", err)
|
||||
}
|
||||
|
||||
responses := []davResponse{
|
||||
{
|
||||
Href: fmt.Sprintf("/carddav/%d/", user.ID),
|
||||
Props: []davProp{
|
||||
{Name: "resourcetype", NS: nsDAV, Value: "<D:collection/>"},
|
||||
{Name: "displayname", NS: nsDAV, Value: "Address Books"},
|
||||
{Name: "addressbook-home-set", NS: nsCardDAV,
|
||||
Value: fmt.Sprintf("<D:href>/carddav/%d/</D:href>", user.ID)},
|
||||
},
|
||||
},
|
||||
}
|
||||
if depth != "0" {
|
||||
for _, ab := range abs {
|
||||
responses = append(responses, addressBookResponse(user.ID, ab))
|
||||
}
|
||||
}
|
||||
writeMultiStatus(w, responses)
|
||||
}
|
||||
|
||||
func (s *Server) propfindAddressBook(w http.ResponseWriter, r *http.Request, user *models.User, ab *models.AddressBook) {
|
||||
depth := r.Header.Get("Depth")
|
||||
responses := []davResponse{addressBookResponse(user.ID, ab)}
|
||||
if depth != "0" {
|
||||
contacts, err := s.deps.DB.ListContacts(r.Context(), ab.ID)
|
||||
if err != nil {
|
||||
log.Printf("[carddav] list contacts: %v", err)
|
||||
}
|
||||
for _, c := range contacts {
|
||||
responses = append(responses, contactResponse(user.ID, ab.ID, c))
|
||||
}
|
||||
}
|
||||
writeMultiStatus(w, responses)
|
||||
}
|
||||
|
||||
func addressBookResponse(userID int64, ab *models.AddressBook) davResponse {
|
||||
ctag := strconv.FormatInt(ab.SyncToken, 10)
|
||||
return davResponse{
|
||||
Href: fmt.Sprintf("/carddav/%d/%d/", userID, ab.ID),
|
||||
Props: []davProp{
|
||||
{Name: "resourcetype", NS: nsDAV,
|
||||
Value: `<D:collection/><CARD:addressbook xmlns:CARD="` + nsCardDAV + `"/>`},
|
||||
{Name: "displayname", NS: nsDAV, Value: xmlEscape(ab.Name)},
|
||||
{Name: "addressbook-description", NS: nsCardDAV, Value: xmlEscape(ab.Description)},
|
||||
{Name: "getctag", NS: nsCS, Value: ctag},
|
||||
{Name: "sync-token", NS: nsDAV,
|
||||
Value: fmt.Sprintf("https://example.com/ns/sync/%d", ab.SyncToken)},
|
||||
{Name: "supported-address-data", NS: nsCardDAV,
|
||||
Value: `<CARD:address-data-type xmlns:CARD="` + nsCardDAV + `" content-type="text/vcard" version="3.0"/>`},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func contactResponse(userID, abID int64, c *models.Contact) davResponse {
|
||||
return davResponse{
|
||||
Href: fmt.Sprintf("/carddav/%d/%d/%s.vcf", userID, abID, c.UID),
|
||||
Props: []davProp{
|
||||
{Name: "resourcetype", NS: nsDAV, Value: ""},
|
||||
{Name: "getetag", NS: nsDAV, Value: `"` + c.ETag + `"`},
|
||||
{Name: "getcontenttype", NS: nsDAV, Value: "text/vcard; charset=utf-8"},
|
||||
{Name: "getlastmodified", NS: nsDAV, Value: c.UpdatedAt.UTC().Format(http.TimeFormat)},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ---- REPORT ----
|
||||
|
||||
func (s *Server) reportAddressBook(w http.ResponseWriter, r *http.Request, user *models.User, ab *models.AddressBook) {
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, 256*1024))
|
||||
if err != nil {
|
||||
http.Error(w, "read error", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
XMLName xml.Name `xml:""`
|
||||
SyncToken string `xml:"sync-token"`
|
||||
Hrefs []string `xml:"href"`
|
||||
}
|
||||
_ = xml.Unmarshal(body, &req)
|
||||
|
||||
switch req.XMLName.Local {
|
||||
case "sync-collection":
|
||||
curToken := fmt.Sprintf("https://example.com/ns/sync/%d", ab.SyncToken)
|
||||
if req.SyncToken == curToken {
|
||||
writeMultiStatus(w, []davResponse{})
|
||||
return
|
||||
}
|
||||
contacts, err := s.deps.DB.ListContacts(r.Context(), ab.ID)
|
||||
if err != nil {
|
||||
log.Printf("[carddav] sync contacts: %v", err)
|
||||
}
|
||||
var responses []davResponse
|
||||
for _, c := range contacts {
|
||||
responses = append(responses, contactResponse(user.ID, ab.ID, c))
|
||||
}
|
||||
writeMultiStatus(w, responses)
|
||||
|
||||
case "addressbook-multiget":
|
||||
var responses []davResponse
|
||||
for _, href := range req.Hrefs {
|
||||
parts := splitPath(strings.TrimPrefix(href, "/carddav"))
|
||||
if len(parts) < 3 {
|
||||
continue
|
||||
}
|
||||
uid := strings.TrimSuffix(parts[len(parts)-1], ".vcf")
|
||||
c, err := s.deps.DB.GetContact(r.Context(), ab.ID, uid)
|
||||
if err != nil || c == nil {
|
||||
responses = append(responses, davResponse{
|
||||
Href: href,
|
||||
Status: "HTTP/1.1 404 Not Found",
|
||||
})
|
||||
continue
|
||||
}
|
||||
raw, err := s.decryptVCard(user.ID, c.VCardEnc)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res := contactResponse(user.ID, ab.ID, c)
|
||||
res.Props = append(res.Props, davProp{
|
||||
Name: "address-data",
|
||||
NS: nsCardDAV,
|
||||
Value: xmlEscape(string(raw)),
|
||||
CData: true,
|
||||
})
|
||||
responses = append(responses, res)
|
||||
}
|
||||
writeMultiStatus(w, responses)
|
||||
|
||||
default:
|
||||
s.propfindAddressBook(w, r, user, ab)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Contact GET/PUT/DELETE ----
|
||||
|
||||
func (s *Server) getContact(w http.ResponseWriter, r *http.Request, user *models.User, ab *models.AddressBook, uid string) {
|
||||
c, err := s.deps.DB.GetContact(r.Context(), ab.ID, uid)
|
||||
if err != nil || c == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if match := r.Header.Get("If-None-Match"); match == `"`+c.ETag+`"` {
|
||||
w.WriteHeader(http.StatusNotModified)
|
||||
return
|
||||
}
|
||||
|
||||
raw, err := s.decryptVCard(user.ID, c.VCardEnc)
|
||||
if err != nil {
|
||||
http.Error(w, "decrypt error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/vcard; charset=utf-8")
|
||||
w.Header().Set("ETag", `"`+c.ETag+`"`)
|
||||
w.Header().Set("Last-Modified", c.UpdatedAt.UTC().Format(http.TimeFormat))
|
||||
if r.Method == "HEAD" {
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(raw)))
|
||||
return
|
||||
}
|
||||
w.Write(raw) //nolint:errcheck
|
||||
}
|
||||
|
||||
func (s *Server) putContact(w http.ResponseWriter, r *http.Request, user *models.User, ab *models.AddressBook, uid string) {
|
||||
if r.ContentLength > 512*1024 {
|
||||
http.Error(w, "Request too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
raw, err := io.ReadAll(io.LimitReader(r.Body, 512*1024))
|
||||
if err != nil || len(raw) == 0 {
|
||||
http.Error(w, "read error", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !strings.Contains(string(raw), "BEGIN:VCARD") {
|
||||
http.Error(w, "Invalid vCard data", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
existing, err := s.deps.DB.GetContact(r.Context(), ab.ID, uid)
|
||||
if err != nil {
|
||||
http.Error(w, "db error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if ifMatch := r.Header.Get("If-Match"); ifMatch != "" && ifMatch != "*" {
|
||||
if existing == nil || `"`+existing.ETag+`"` != ifMatch {
|
||||
http.Error(w, "Precondition Failed", http.StatusPreconditionFailed)
|
||||
return
|
||||
}
|
||||
}
|
||||
if r.Header.Get("If-None-Match") == "*" && existing != nil {
|
||||
http.Error(w, "Precondition Failed", http.StatusPreconditionFailed)
|
||||
return
|
||||
}
|
||||
|
||||
vcardEnc, err := s.encryptVCard(user.ID, raw)
|
||||
if err != nil {
|
||||
http.Error(w, "encrypt error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
etag := sha256Hex(raw)
|
||||
if err := s.deps.DB.UpsertContact(r.Context(), ab.ID, uid, etag, vcardEnc); err != nil {
|
||||
log.Printf("[carddav] upsert contact: %v", err)
|
||||
http.Error(w, "db error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if _, err := s.deps.DB.BumpAddressBookSyncToken(r.Context(), ab.ID); err != nil {
|
||||
log.Printf("[carddav] bump token: %v", err)
|
||||
}
|
||||
w.Header().Set("ETag", `"`+etag+`"`)
|
||||
if existing == nil {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) deleteContact(w http.ResponseWriter, r *http.Request, user *models.User, ab *models.AddressBook, uid string) {
|
||||
c, err := s.deps.DB.GetContact(r.Context(), ab.ID, uid)
|
||||
if err != nil || c == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if ifMatch := r.Header.Get("If-Match"); ifMatch != "" && ifMatch != "*" {
|
||||
if `"`+c.ETag+`"` != ifMatch {
|
||||
http.Error(w, "Precondition Failed", http.StatusPreconditionFailed)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := s.deps.DB.DeleteContact(r.Context(), ab.ID, uid); err != nil {
|
||||
http.Error(w, "db error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if _, err := s.deps.DB.BumpAddressBookSyncToken(r.Context(), ab.ID); err != nil {
|
||||
log.Printf("[carddav] bump token delete: %v", err)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ---- Encryption ----
|
||||
|
||||
func (s *Server) encryptVCard(userID int64, plain []byte) ([]byte, error) {
|
||||
key, err := s.deps.Crypt.DeriveKey("vcard", userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return appCrypto.Encrypt(key, plain)
|
||||
}
|
||||
|
||||
func (s *Server) decryptVCard(userID int64, enc []byte) ([]byte, error) {
|
||||
key, err := s.deps.Crypt.DeriveKey("vcard", userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return appCrypto.Decrypt(key, enc)
|
||||
}
|
||||
|
||||
// ---- XML helpers (shared pattern with caldav) ----
|
||||
|
||||
type davProp struct {
|
||||
Name string
|
||||
NS string
|
||||
Value string
|
||||
CData bool
|
||||
}
|
||||
|
||||
type davResponse struct {
|
||||
Href string
|
||||
Status string
|
||||
Props []davProp
|
||||
}
|
||||
|
||||
func writeMultiStatus(w http.ResponseWriter, responses []davResponse) {
|
||||
w.Header().Set("Content-Type", "application/xml; charset=utf-8")
|
||||
w.WriteHeader(http.StatusMultiStatus)
|
||||
|
||||
fmt.Fprintf(w, `<?xml version="1.0" encoding="UTF-8"?>`+"\n")
|
||||
fmt.Fprintf(w, `<D:multistatus xmlns:D="DAV:" xmlns:CARD="urn:ietf:params:xml:ns:carddav" xmlns:CS="http://calendarserver.org/ns/">`+"\n")
|
||||
|
||||
for _, resp := range responses {
|
||||
fmt.Fprintf(w, " <D:response>\n")
|
||||
fmt.Fprintf(w, " <D:href>%s</D:href>\n", xmlEscape(resp.Href))
|
||||
if resp.Status != "" {
|
||||
fmt.Fprintf(w, " <D:status>%s</D:status>\n", xmlEscape(resp.Status))
|
||||
} else if len(resp.Props) > 0 {
|
||||
fmt.Fprintf(w, " <D:propstat>\n <D:prop>\n")
|
||||
for _, p := range resp.Props {
|
||||
writeDAVProp(w, p)
|
||||
}
|
||||
fmt.Fprintf(w, " </D:prop>\n <D:status>HTTP/1.1 200 OK</D:status>\n </D:propstat>\n")
|
||||
} else {
|
||||
fmt.Fprintf(w, " <D:status>HTTP/1.1 200 OK</D:status>\n")
|
||||
}
|
||||
fmt.Fprintf(w, " </D:response>\n")
|
||||
}
|
||||
fmt.Fprintf(w, "</D:multistatus>\n")
|
||||
}
|
||||
|
||||
func writeDAVProp(w http.ResponseWriter, p davProp) {
|
||||
ns := "D"
|
||||
switch p.NS {
|
||||
case nsCardDAV:
|
||||
ns = "CARD"
|
||||
case nsCS:
|
||||
ns = "CS"
|
||||
}
|
||||
if p.Value == "" {
|
||||
fmt.Fprintf(w, " <%s:%s/>\n", ns, p.Name)
|
||||
return
|
||||
}
|
||||
if p.CData {
|
||||
fmt.Fprintf(w, " <%s:%s><![CDATA[%s]]></%s:%s>\n", ns, p.Name, p.Value, ns, p.Name)
|
||||
} else {
|
||||
fmt.Fprintf(w, " <%s:%s>%s</%s:%s>\n", ns, p.Name, p.Value, ns, p.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func splitPath(path string) []string {
|
||||
var out []string
|
||||
for _, s := range strings.Split(strings.Trim(path, "/"), "/") {
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func xmlEscape(s string) string {
|
||||
s = strings.ReplaceAll(s, "&", "&")
|
||||
s = strings.ReplaceAll(s, "<", "<")
|
||||
s = strings.ReplaceAll(s, ">", ">")
|
||||
s = strings.ReplaceAll(s, `"`, """)
|
||||
return s
|
||||
}
|
||||
|
||||
func sha256Hex(data []byte) string {
|
||||
h := gocrypto.Sum256(data)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
@@ -680,6 +680,56 @@ func logStartup(c *Config) {
|
||||
fmt.Printf(" CalDAV : %s:%d\n", c.CalDAVIface, c.CalDAVPort)
|
||||
}
|
||||
|
||||
// ---- Validation ----
|
||||
|
||||
// Validate checks the configuration for missing required fields and insecure
|
||||
// defaults, printing warnings to stdout. Returns a non-nil error only for
|
||||
// truly fatal conditions (no listening ports enabled).
|
||||
func (c *Config) Validate() error {
|
||||
warn := func(msg string) { fmt.Printf("[config] WARNING: %s\n", msg) }
|
||||
|
||||
// Secrets
|
||||
if len(c.SessionSecret) < 32 {
|
||||
warn("SESSION_SECRET is missing or too short (< 32 bytes). Regenerate it.")
|
||||
}
|
||||
if len(c.EncryptionKey) == 0 {
|
||||
warn("ENCRYPTION_KEY is missing. Emails will not be encrypted at rest.")
|
||||
}
|
||||
|
||||
// Hostname / domain
|
||||
if c.Hostname == "" || c.Hostname == "mail.example.com" {
|
||||
warn("HOSTNAME is not set to a real FQDN. SMTP HELO will be rejected by strict servers.")
|
||||
}
|
||||
if c.DefaultDomain == "" || c.DefaultDomain == "example.com" {
|
||||
warn("DEFAULT_DOMAIN is not set. Email delivery may fail.")
|
||||
}
|
||||
|
||||
// ACME
|
||||
if (c.TLSMode == "dns01" || c.TLSMode == "http01") && c.ACMEEmail == "" {
|
||||
warn("ACME_EMAIL is required for automatic TLS certificate provisioning.")
|
||||
}
|
||||
|
||||
// Ports — at least one service should be listening.
|
||||
anyPort := c.SMTPEnabled || c.SubmitEnabled || c.SMTPSEnabled ||
|
||||
c.IMAPEnabled || c.IMAPSEnabled ||
|
||||
c.WebClientPort > 0 || c.WebAdminPort > 0
|
||||
if !anyPort {
|
||||
return fmt.Errorf("no services enabled — check SMTP_ENABLED, IMAP_ENABLED, WEB_CLIENT_PORT, WEB_ADMIN_PORT")
|
||||
}
|
||||
|
||||
// Admin binding — warn if admin panel is exposed on non-loopback.
|
||||
if c.WebAdminPort > 0 && c.WebAdminIface != "127.0.0.1" && c.WebAdminIface != "::1" {
|
||||
warn(fmt.Sprintf("WEB_ADMIN_IFACE=%q exposes the admin panel on a public interface. Consider restricting to 127.0.0.1.", c.WebAdminIface))
|
||||
}
|
||||
|
||||
// Queue
|
||||
if c.QueueMaxAgeHours <= 0 {
|
||||
warn("QUEUE_MAX_AGE_HOURS is 0 or negative. Queued messages will never expire.")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---- Helpers ----
|
||||
|
||||
func mustHex(n int) string {
|
||||
|
||||
@@ -0,0 +1,373 @@
|
||||
// Package db — admin-specific queries: queue, ip_bans, security_events,
|
||||
// and update/delete operations for users and domains.
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---- Types ----
|
||||
|
||||
// QueueEntry is a row from the delivery queue.
|
||||
type QueueEntry struct {
|
||||
ID int64
|
||||
DomainID sql.NullInt64
|
||||
FromAddr string
|
||||
ToAddr string
|
||||
MessageID string
|
||||
Status string // pending | failed | sent
|
||||
Attempts int
|
||||
LastAttempt sql.NullTime
|
||||
NextAttempt time.Time
|
||||
ErrorLog string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// IPBan is a row from ip_bans.
|
||||
type IPBan struct {
|
||||
ID int64
|
||||
IP string
|
||||
Reason string
|
||||
BannedAt time.Time
|
||||
ExpiresAt sql.NullTime
|
||||
ReleasedBy string
|
||||
}
|
||||
|
||||
// SecurityEvent is a row from security_events.
|
||||
type SecurityEvent struct {
|
||||
ID int64
|
||||
Type string
|
||||
IP string
|
||||
UserID sql.NullInt64
|
||||
Detail string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// AdminStats aggregates summary counts for the dashboard.
|
||||
type AdminStats struct {
|
||||
TotalDomains int
|
||||
TotalUsers int
|
||||
TotalMessages int
|
||||
QueuePending int
|
||||
QueueFailed int
|
||||
ActiveBans int
|
||||
RecentEvents int // last 24h
|
||||
}
|
||||
|
||||
// ---- Queue ----
|
||||
|
||||
// ListQueueEntries returns all non-sent queue entries ordered by created_at desc.
|
||||
func (d *DB) ListQueueEntries(ctx context.Context) ([]*QueueEntry, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, domain_id, from_addr, to_addr, message_id, status,
|
||||
attempts, last_attempt, next_attempt, error_log, created_at, expires_at
|
||||
FROM queue
|
||||
WHERE status != 'sent'
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 500`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*QueueEntry
|
||||
for rows.Next() {
|
||||
q := &QueueEntry{}
|
||||
err := rows.Scan(
|
||||
&q.ID, &q.DomainID, &q.FromAddr, &q.ToAddr, &q.MessageID, &q.Status,
|
||||
&q.Attempts, &q.LastAttempt, &q.NextAttempt, &q.ErrorLog, &q.CreatedAt, &q.ExpiresAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, q)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// RetryQueueEntry resets a queue entry to pending with immediate next_attempt.
|
||||
func (d *DB) RetryQueueEntry(ctx context.Context, id int64) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE queue SET status='pending', next_attempt=? WHERE id=?",
|
||||
time.Now().UTC(), id)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteQueueEntry removes a queue entry permanently.
|
||||
func (d *DB) DeleteQueueEntry(ctx context.Context, id int64) error {
|
||||
_, err := d.db.ExecContext(ctx, "DELETE FROM queue WHERE id=?", id)
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- IP Bans ----
|
||||
|
||||
// ListIPBans returns all IP bans, active first.
|
||||
func (d *DB) ListIPBans(ctx context.Context) ([]*IPBan, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, ip, reason, banned_at, expires_at, released_by
|
||||
FROM ip_bans
|
||||
ORDER BY banned_at DESC
|
||||
LIMIT 500`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*IPBan
|
||||
for rows.Next() {
|
||||
b := &IPBan{}
|
||||
err := rows.Scan(&b.ID, &b.IP, &b.Reason, &b.BannedAt, &b.ExpiresAt, &b.ReleasedBy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, b)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// AddIPBan inserts a manual IP ban. hours=0 means permanent.
|
||||
func (d *DB) AddIPBan(ctx context.Context, ip, reason string, hours int) error {
|
||||
var expiresAt *time.Time
|
||||
if hours > 0 {
|
||||
t := time.Now().UTC().Add(time.Duration(hours) * time.Hour)
|
||||
expiresAt = &t
|
||||
}
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"INSERT OR REPLACE INTO ip_bans (ip, reason, banned_at, expires_at) VALUES (?, ?, ?, ?)",
|
||||
ip, reason, time.Now().UTC(), expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveIPBan deletes a ban by IP address.
|
||||
func (d *DB) RemoveIPBan(ctx context.Context, ip string) error {
|
||||
_, err := d.db.ExecContext(ctx, "DELETE FROM ip_bans WHERE ip=?", ip)
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- Security Events ----
|
||||
|
||||
// ListSecurityEvents returns recent security events.
|
||||
func (d *DB) ListSecurityEvents(ctx context.Context, limit int) ([]*SecurityEvent, error) {
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 200
|
||||
}
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, type, ip, user_id, detail, created_at
|
||||
FROM security_events
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*SecurityEvent
|
||||
for rows.Next() {
|
||||
ev := &SecurityEvent{}
|
||||
err := rows.Scan(&ev.ID, &ev.Type, &ev.IP, &ev.UserID, &ev.Detail, &ev.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, ev)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// LogSecurityEvent inserts a security event record.
|
||||
func (d *DB) LogSecurityEvent(ctx context.Context, eventType, ip string, userID *int64, detail string) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"INSERT INTO security_events (type, ip, user_id, detail, created_at) VALUES (?, ?, ?, ?, ?)",
|
||||
eventType, ip, userID, detail, time.Now().UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- User admin operations ----
|
||||
|
||||
// SetUserEnabled enables or disables a user account.
|
||||
func (d *DB) SetUserEnabled(ctx context.Context, userID int64, enabled bool) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE users SET enabled=? WHERE id=?", enabled, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetUserAdmin sets admin / domain_admin flags.
|
||||
func (d *DB) SetUserAdmin(ctx context.Context, userID int64, admin, domainAdmin bool) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE users SET admin=?, domain_admin=? WHERE id=?", admin, domainAdmin, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetUserQuota updates quota_bytes.
|
||||
func (d *DB) SetUserQuota(ctx context.Context, userID int64, quotaBytes int64) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE users SET quota_bytes=? WHERE id=?", quotaBytes, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetUserPassword replaces the bcrypt hash for a user.
|
||||
func (d *DB) SetUserPassword(ctx context.Context, userID int64, hash string) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE users SET password_hash=? WHERE id=?", hash, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetUserDisplayName updates display_name.
|
||||
func (d *DB) SetUserDisplayName(ctx context.Context, userID int64, name string) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE users SET display_name=? WHERE id=?", name, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUser permanently deletes a user and all associated data (cascade).
|
||||
func (d *DB) DeleteUser(ctx context.Context, userID int64) error {
|
||||
_, err := d.db.ExecContext(ctx, "DELETE FROM users WHERE id=?", userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListAllUsers returns users across all domains, joined with domain name.
|
||||
func (d *DB) ListAllUsers(ctx context.Context) ([]*UserWithDomain, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT u.id, u.domain_id, u.username, u.email, u.display_name,
|
||||
u.quota_bytes, u.used_bytes, u.enabled, u.admin, u.domain_admin,
|
||||
u.mfa_enabled, u.created_at, u.last_login,
|
||||
d.name AS domain_name
|
||||
FROM users u
|
||||
LEFT JOIN domains d ON d.id = u.domain_id
|
||||
ORDER BY d.name, u.email`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*UserWithDomain
|
||||
for rows.Next() {
|
||||
u := &UserWithDomain{}
|
||||
var lastLogin sql.NullTime
|
||||
err := rows.Scan(
|
||||
&u.ID, &u.DomainID, &u.Username, &u.Email, &u.DisplayName,
|
||||
&u.QuotaBytes, &u.UsedBytes, &u.Enabled, &u.Admin, &u.DomainAdmin,
|
||||
&u.MFAEnabled, &u.CreatedAt, &lastLogin,
|
||||
&u.DomainName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lastLogin.Valid {
|
||||
u.LastLogin = lastLogin.Time
|
||||
}
|
||||
out = append(out, u)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// UserWithDomain is a user row augmented with the domain name.
|
||||
type UserWithDomain struct {
|
||||
ID int64
|
||||
DomainID int64
|
||||
Username string
|
||||
Email string
|
||||
DisplayName string
|
||||
DomainName string
|
||||
QuotaBytes int64
|
||||
UsedBytes int64
|
||||
Enabled bool
|
||||
Admin bool
|
||||
DomainAdmin bool
|
||||
MFAEnabled bool
|
||||
CreatedAt time.Time
|
||||
LastLogin time.Time
|
||||
}
|
||||
|
||||
// ---- MFA operations ----
|
||||
|
||||
// SetMFASecret stores the encrypted TOTP secret. Passing nil clears the secret.
|
||||
func (d *DB) SetMFASecret(ctx context.Context, userID int64, encSecret []byte) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE users SET mfa_secret_enc=? WHERE id=?", encSecret, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetMFAEnabled enables or disables TOTP for a user.
|
||||
func (d *DB) SetMFAEnabled(ctx context.Context, userID int64, enabled bool) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE users SET mfa_enabled=? WHERE id=?", enabled, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetRecoveryCodes stores the encrypted recovery codes JSON.
|
||||
func (d *DB) SetRecoveryCodes(ctx context.Context, userID int64, encCodes []byte) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE users SET recovery_codes_enc=? WHERE id=?", encCodes, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ClearMFA disables MFA and removes the secret + recovery codes atomically.
|
||||
func (d *DB) ClearMFA(ctx context.Context, userID int64) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE users SET mfa_enabled=0, mfa_secret_enc=NULL, recovery_codes_enc=NULL WHERE id=?",
|
||||
userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- Domain admin operations ----
|
||||
|
||||
// SetDomainEnabled enables or disables a domain.
|
||||
func (d *DB) SetDomainEnabled(ctx context.Context, domainID int64, enabled bool) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE domains SET enabled=? WHERE id=?", enabled, domainID)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetDomainLimits updates max_users and max_quota_bytes.
|
||||
func (d *DB) SetDomainLimits(ctx context.Context, domainID int64, maxUsers int, maxQuota int64) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE domains SET max_users=?, max_quota_bytes=? WHERE id=?",
|
||||
maxUsers, maxQuota, domainID)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetDomainDNS stores SPF and DMARC policy strings (informational, for display).
|
||||
func (d *DB) SetDomainDNS(ctx context.Context, domainID int64, spf, dmarc string) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE domains SET spf_policy=?, dmarc_policy=? WHERE id=?", spf, dmarc, domainID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteDomain permanently removes a domain (cascade deletes users + their data).
|
||||
func (d *DB) DeleteDomain(ctx context.Context, domainID int64) error {
|
||||
_, err := d.db.ExecContext(ctx, "DELETE FROM domains WHERE id=?", domainID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- Stats ----
|
||||
|
||||
// GetAdminStats returns aggregate counts for the admin dashboard.
|
||||
func (d *DB) GetAdminStats(ctx context.Context) (*AdminStats, error) {
|
||||
s := &AdminStats{}
|
||||
|
||||
queries := []struct {
|
||||
dest *int
|
||||
sql string
|
||||
args []any
|
||||
}{
|
||||
{&s.TotalDomains, "SELECT COUNT(*) FROM domains WHERE enabled=1", nil},
|
||||
{&s.TotalUsers, "SELECT COUNT(*) FROM users WHERE enabled=1", nil},
|
||||
{&s.TotalMessages, "SELECT COUNT(*) FROM messages WHERE deleted_at IS NULL", nil},
|
||||
{&s.QueuePending, "SELECT COUNT(*) FROM queue WHERE status='pending'", nil},
|
||||
{&s.QueueFailed, "SELECT COUNT(*) FROM queue WHERE status='failed'", nil},
|
||||
{&s.ActiveBans, "SELECT COUNT(*) FROM ip_bans WHERE (expires_at IS NULL OR expires_at > ?)", []any{time.Now().UTC()}},
|
||||
{&s.RecentEvents, "SELECT COUNT(*) FROM security_events WHERE created_at > ?", []any{time.Now().UTC().Add(-24 * time.Hour)}},
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
err := d.db.QueryRowContext(ctx, q.sql, q.args...).Scan(q.dest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("admin stats: %w", err)
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
@@ -0,0 +1,325 @@
|
||||
// Package db — CalDAV and CardDAV database operations.
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
// ---- CalDAV ----
|
||||
|
||||
// ListCalendars returns all calendars for a user.
|
||||
func (d *DB) ListCalendars(ctx context.Context, userID int64) ([]*models.Calendar, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, user_id, name, description, color, timezone, sync_token, created_at
|
||||
FROM calendars WHERE user_id=? ORDER BY name`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*models.Calendar
|
||||
for rows.Next() {
|
||||
c := &models.Calendar{}
|
||||
if err := rows.Scan(&c.ID, &c.UserID, &c.Name, &c.Description, &c.Color, &c.Timezone, &c.SyncToken, &c.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, c)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// GetCalendarByID returns a calendar by ID.
|
||||
func (d *DB) GetCalendarByID(ctx context.Context, id int64) (*models.Calendar, error) {
|
||||
row := d.db.QueryRowContext(ctx,
|
||||
"SELECT id, user_id, name, description, color, timezone, sync_token, created_at FROM calendars WHERE id=?", id)
|
||||
c := &models.Calendar{}
|
||||
err := row.Scan(&c.ID, &c.UserID, &c.Name, &c.Description, &c.Color, &c.Timezone, &c.SyncToken, &c.CreatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
// CreateCalendar inserts a new calendar and returns the ID.
|
||||
func (d *DB) CreateCalendar(ctx context.Context, userID int64, name, description, color, timezone string) (int64, error) {
|
||||
if color == "" {
|
||||
color = "#4CAF50"
|
||||
}
|
||||
if timezone == "" {
|
||||
timezone = "UTC"
|
||||
}
|
||||
res, err := d.db.ExecContext(ctx,
|
||||
"INSERT INTO calendars (user_id, name, description, color, timezone, sync_token, created_at) VALUES (?,?,?,?,?,1,?)",
|
||||
userID, name, description, color, timezone, time.Now().UTC())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("create calendar: %w", err)
|
||||
}
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// EnsureDefaultCalendar creates a "Personal" calendar for a user if none exists.
|
||||
func (d *DB) EnsureDefaultCalendar(ctx context.Context, userID int64) (*models.Calendar, error) {
|
||||
cals, err := d.ListCalendars(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(cals) > 0 {
|
||||
return cals[0], nil
|
||||
}
|
||||
id, err := d.CreateCalendar(ctx, userID, "Personal", "", "#4CAF50", "UTC")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.GetCalendarByID(ctx, id)
|
||||
}
|
||||
|
||||
// DeleteCalendar removes a calendar and all its events.
|
||||
func (d *DB) DeleteCalendar(ctx context.Context, calendarID int64) error {
|
||||
_, err := d.db.ExecContext(ctx, "DELETE FROM calendars WHERE id=?", calendarID)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateCalendar updates calendar metadata.
|
||||
func (d *DB) UpdateCalendar(ctx context.Context, calendarID int64, name, description, color, timezone string) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE calendars SET name=?, description=?, color=?, timezone=? WHERE id=?",
|
||||
name, description, color, timezone, calendarID)
|
||||
return err
|
||||
}
|
||||
|
||||
// BumpCalendarSyncToken increments sync_token and returns the new value.
|
||||
func (d *DB) BumpCalendarSyncToken(ctx context.Context, calendarID int64) (int64, error) {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE calendars SET sync_token = sync_token + 1 WHERE id=?", calendarID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var token int64
|
||||
err = d.db.QueryRowContext(ctx, "SELECT sync_token FROM calendars WHERE id=?", calendarID).Scan(&token)
|
||||
return token, err
|
||||
}
|
||||
|
||||
// ---- Calendar events ----
|
||||
|
||||
// ListCalendarEvents returns all events in a calendar.
|
||||
func (d *DB) ListCalendarEvents(ctx context.Context, calendarID int64) ([]*models.CalendarEvent, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, calendar_id, uid, ical_enc, etag, dt_start, dt_end, summary, recurring, created_at, updated_at
|
||||
FROM calendar_events WHERE calendar_id=? ORDER BY dt_start`, calendarID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*models.CalendarEvent
|
||||
for rows.Next() {
|
||||
ev := &models.CalendarEvent{}
|
||||
var dtStart, dtEnd sql.NullTime
|
||||
err := rows.Scan(&ev.ID, &ev.CalendarID, &ev.UID, &ev.ICalEnc, &ev.ETag,
|
||||
&dtStart, &dtEnd, &ev.Summary, &ev.Recurring, &ev.CreatedAt, &ev.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dtStart.Valid {
|
||||
ev.DTStart = dtStart.Time
|
||||
}
|
||||
if dtEnd.Valid {
|
||||
ev.DTEnd = dtEnd.Time
|
||||
}
|
||||
out = append(out, ev)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// GetCalendarEvent returns one event by UID within a calendar.
|
||||
func (d *DB) GetCalendarEvent(ctx context.Context, calendarID int64, uid string) (*models.CalendarEvent, error) {
|
||||
row := d.db.QueryRowContext(ctx, `
|
||||
SELECT id, calendar_id, uid, ical_enc, etag, dt_start, dt_end, summary, recurring, created_at, updated_at
|
||||
FROM calendar_events WHERE calendar_id=? AND uid=?`, calendarID, uid)
|
||||
ev := &models.CalendarEvent{}
|
||||
var dtStart, dtEnd sql.NullTime
|
||||
err := row.Scan(&ev.ID, &ev.CalendarID, &ev.UID, &ev.ICalEnc, &ev.ETag,
|
||||
&dtStart, &dtEnd, &ev.Summary, &ev.Recurring, &ev.CreatedAt, &ev.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dtStart.Valid {
|
||||
ev.DTStart = dtStart.Time
|
||||
}
|
||||
if dtEnd.Valid {
|
||||
ev.DTEnd = dtEnd.Time
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
// UpsertCalendarEvent creates or replaces a calendar event.
|
||||
func (d *DB) UpsertCalendarEvent(ctx context.Context, calendarID int64, uid, etag string, icalEnc []byte, dtStart, dtEnd time.Time, summary string, recurring bool) error {
|
||||
now := time.Now().UTC()
|
||||
_, err := d.db.ExecContext(ctx, `
|
||||
INSERT INTO calendar_events (calendar_id, uid, ical_enc, etag, dt_start, dt_end, summary, recurring, created_at, updated_at)
|
||||
VALUES (?,?,?,?,?,?,?,?,?,?)
|
||||
ON CONFLICT(calendar_id, uid) DO UPDATE SET
|
||||
ical_enc=excluded.ical_enc, etag=excluded.etag,
|
||||
dt_start=excluded.dt_start, dt_end=excluded.dt_end,
|
||||
summary=excluded.summary, recurring=excluded.recurring,
|
||||
updated_at=excluded.updated_at`,
|
||||
calendarID, uid, icalEnc, etag, nullTime(dtStart), nullTime(dtEnd), summary, recurring, now, now)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteCalendarEvent removes one event by UID.
|
||||
func (d *DB) DeleteCalendarEvent(ctx context.Context, calendarID int64, uid string) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"DELETE FROM calendar_events WHERE calendar_id=? AND uid=?", calendarID, uid)
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- CardDAV ----
|
||||
|
||||
// ListAddressBooks returns all address books for a user.
|
||||
func (d *DB) ListAddressBooks(ctx context.Context, userID int64) ([]*models.AddressBook, error) {
|
||||
rows, err := d.db.QueryContext(ctx,
|
||||
"SELECT id, user_id, name, description, color, sync_token, created_at FROM address_books WHERE user_id=? ORDER BY name", userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*models.AddressBook
|
||||
for rows.Next() {
|
||||
ab := &models.AddressBook{}
|
||||
if err := rows.Scan(&ab.ID, &ab.UserID, &ab.Name, &ab.Description, &ab.Color, &ab.SyncToken, &ab.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, ab)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// GetAddressBookByID returns an address book by ID.
|
||||
func (d *DB) GetAddressBookByID(ctx context.Context, id int64) (*models.AddressBook, error) {
|
||||
row := d.db.QueryRowContext(ctx,
|
||||
"SELECT id, user_id, name, description, color, sync_token, created_at FROM address_books WHERE id=?", id)
|
||||
ab := &models.AddressBook{}
|
||||
err := row.Scan(&ab.ID, &ab.UserID, &ab.Name, &ab.Description, &ab.Color, &ab.SyncToken, &ab.CreatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return ab, err
|
||||
}
|
||||
|
||||
// CreateAddressBook inserts a new address book.
|
||||
func (d *DB) CreateAddressBook(ctx context.Context, userID int64, name, description, color string) (int64, error) {
|
||||
if color == "" {
|
||||
color = "#4A90E2"
|
||||
}
|
||||
res, err := d.db.ExecContext(ctx,
|
||||
"INSERT INTO address_books (user_id, name, description, color, sync_token, created_at) VALUES (?,?,?,?,1,?)",
|
||||
userID, name, description, color, time.Now().UTC())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("create address book: %w", err)
|
||||
}
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// EnsureDefaultAddressBook creates a "Personal" address book for a user if none exists.
|
||||
func (d *DB) EnsureDefaultAddressBook(ctx context.Context, userID int64) (*models.AddressBook, error) {
|
||||
abs, err := d.ListAddressBooks(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(abs) > 0 {
|
||||
return abs[0], nil
|
||||
}
|
||||
id, err := d.CreateAddressBook(ctx, userID, "Personal", "", "#4A90E2")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.GetAddressBookByID(ctx, id)
|
||||
}
|
||||
|
||||
// DeleteAddressBook removes an address book and all contacts.
|
||||
func (d *DB) DeleteAddressBook(ctx context.Context, addressBookID int64) error {
|
||||
_, err := d.db.ExecContext(ctx, "DELETE FROM address_books WHERE id=?", addressBookID)
|
||||
return err
|
||||
}
|
||||
|
||||
// BumpAddressBookSyncToken increments sync_token and returns the new value.
|
||||
func (d *DB) BumpAddressBookSyncToken(ctx context.Context, addressBookID int64) (int64, error) {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE address_books SET sync_token = sync_token + 1 WHERE id=?", addressBookID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var token int64
|
||||
err = d.db.QueryRowContext(ctx, "SELECT sync_token FROM address_books WHERE id=?", addressBookID).Scan(&token)
|
||||
return token, err
|
||||
}
|
||||
|
||||
// ---- Contacts ----
|
||||
|
||||
// ListContacts returns all contacts in an address book.
|
||||
func (d *DB) ListContacts(ctx context.Context, addressBookID int64) ([]*models.Contact, error) {
|
||||
rows, err := d.db.QueryContext(ctx,
|
||||
"SELECT id, address_book_id, uid, vcard_enc, etag, created_at, updated_at FROM contacts WHERE address_book_id=? ORDER BY uid",
|
||||
addressBookID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []*models.Contact
|
||||
for rows.Next() {
|
||||
c := &models.Contact{}
|
||||
if err := rows.Scan(&c.ID, &c.AddressBookID, &c.UID, &c.VCardEnc, &c.ETag, &c.CreatedAt, &c.UpdatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, c)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// GetContact returns one contact by UID.
|
||||
func (d *DB) GetContact(ctx context.Context, addressBookID int64, uid string) (*models.Contact, error) {
|
||||
row := d.db.QueryRowContext(ctx,
|
||||
"SELECT id, address_book_id, uid, vcard_enc, etag, created_at, updated_at FROM contacts WHERE address_book_id=? AND uid=?",
|
||||
addressBookID, uid)
|
||||
c := &models.Contact{}
|
||||
err := row.Scan(&c.ID, &c.AddressBookID, &c.UID, &c.VCardEnc, &c.ETag, &c.CreatedAt, &c.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
// UpsertContact creates or replaces a contact.
|
||||
func (d *DB) UpsertContact(ctx context.Context, addressBookID int64, uid, etag string, vcardEnc []byte) error {
|
||||
now := time.Now().UTC()
|
||||
_, err := d.db.ExecContext(ctx, `
|
||||
INSERT INTO contacts (address_book_id, uid, vcard_enc, etag, created_at, updated_at)
|
||||
VALUES (?,?,?,?,?,?)
|
||||
ON CONFLICT(address_book_id, uid) DO UPDATE SET
|
||||
vcard_enc=excluded.vcard_enc, etag=excluded.etag, updated_at=excluded.updated_at`,
|
||||
addressBookID, uid, vcardEnc, etag, now, now)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteContact removes a contact by UID.
|
||||
func (d *DB) DeleteContact(ctx context.Context, addressBookID int64, uid string) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"DELETE FROM contacts WHERE address_book_id=? AND uid=?", addressBookID, uid)
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- helpers ----
|
||||
|
||||
func nullTime(t time.Time) *time.Time {
|
||||
if t.IsZero() {
|
||||
return nil
|
||||
}
|
||||
return &t
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
// Package middleware provides reusable HTTP middleware for security headers
|
||||
// and in-memory rate limiting (token bucket per IP).
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---- Security Headers ----
|
||||
|
||||
// SecureHeaders wraps an http.Handler and adds security-relevant response headers.
|
||||
// The CSP is intentionally strict: scripts only from same-origin CDN; no inline scripts.
|
||||
func SecureHeaders(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := w.Header()
|
||||
h.Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
|
||||
h.Set("X-Content-Type-Options", "nosniff")
|
||||
h.Set("X-Frame-Options", "SAMEORIGIN")
|
||||
h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
h.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()")
|
||||
h.Set("Content-Security-Policy",
|
||||
"default-src 'self'; "+
|
||||
"script-src 'self' https://cdn.tailwindcss.com; "+
|
||||
"style-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; "+
|
||||
"img-src 'self' data:; "+
|
||||
"font-src 'self'; "+
|
||||
"frame-src 'self'; "+
|
||||
"object-src 'none'; "+
|
||||
"base-uri 'self'; "+
|
||||
"form-action 'self'")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// ---- Token-bucket rate limiter ----
|
||||
|
||||
// bucket is one client's token state.
|
||||
type bucket struct {
|
||||
tokens float64
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
// RateLimiter is an in-memory token-bucket limiter keyed by IP.
|
||||
// Safe for concurrent use.
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
buckets map[string]*bucket
|
||||
rate float64 // tokens refilled per second
|
||||
capacity float64 // max token count
|
||||
status int // HTTP status on exceeded (default 429)
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a limiter.
|
||||
// - ratePerMin: tokens refilled per minute (e.g. 20 = 20 req/min steady-state)
|
||||
// - burst: max burst size (e.g. 5 = 5 simultaneous requests)
|
||||
func NewRateLimiter(ratePerMin int, burst int) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
buckets: make(map[string]*bucket),
|
||||
rate: float64(ratePerMin) / 60.0,
|
||||
capacity: float64(burst),
|
||||
status: http.StatusTooManyRequests,
|
||||
}
|
||||
// Periodic cleanup goroutine (runs for process lifetime).
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
// Allow returns true if the request is within limit, consuming one token.
|
||||
func (rl *RateLimiter) Allow(ip string) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
b, ok := rl.buckets[ip]
|
||||
if !ok {
|
||||
b = &bucket{tokens: rl.capacity, lastSeen: now}
|
||||
rl.buckets[ip] = b
|
||||
}
|
||||
|
||||
// Refill tokens based on elapsed time.
|
||||
elapsed := now.Sub(b.lastSeen).Seconds()
|
||||
b.tokens += elapsed * rl.rate
|
||||
if b.tokens > rl.capacity {
|
||||
b.tokens = rl.capacity
|
||||
}
|
||||
b.lastSeen = now
|
||||
|
||||
if b.tokens < 1 {
|
||||
return false
|
||||
}
|
||||
b.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
// Middleware returns an http.Handler that rate-limits by remote IP.
|
||||
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := remoteIP(r)
|
||||
if !rl.Allow(ip) {
|
||||
w.Header().Set("Retry-After", "60")
|
||||
http.Error(w, http.StatusText(rl.status), rl.status)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// cleanup removes stale buckets every 5 minutes (no activity for >10 min).
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
cutoff := time.Now().Add(-10 * time.Minute)
|
||||
rl.mu.Lock()
|
||||
for ip, b := range rl.buckets {
|
||||
if b.lastSeen.Before(cutoff) {
|
||||
delete(rl.buckets, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// remoteIP extracts the real client IP, honoring X-Real-IP and X-Forwarded-For.
|
||||
// Falls back to RemoteAddr.
|
||||
func remoteIP(r *http.Request) string {
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
if ip := net.ParseIP(xri); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// First address in the chain is the original client.
|
||||
for _, part := range splitComma(xff) {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
if ip := net.ParseIP(trimmed); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// splitComma splits on comma without allocating a regex.
|
||||
func splitComma(s string) []string {
|
||||
var out []string
|
||||
start := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == ',' {
|
||||
out = append(out, s[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
out = append(out, s[start:])
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/delivery"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/storage"
|
||||
)
|
||||
|
||||
// QueueWorker polls the delivery queue and dispatches messages.
|
||||
@@ -106,6 +113,7 @@ func (w *QueueWorker) drainQueue() {
|
||||
}
|
||||
|
||||
// markFailed updates queue status with exponential back-off or marks permanent failure.
|
||||
// On permanent failure with a non-empty sender, generates a DSN (RFC 3464) bounce.
|
||||
func (w *QueueWorker) markFailed(ctx context.Context, queueID int64, from, to, errMsg string, perm bool) {
|
||||
status := "failed"
|
||||
if perm {
|
||||
@@ -125,6 +133,135 @@ func (w *QueueWorker) markFailed(ctx context.Context, queueID int64, from, to, e
|
||||
}
|
||||
w.deps.DB.LogDelivery(ctx, queueID, from, to, status, 0, errMsg, "") //nolint:errcheck
|
||||
log.Printf("[queue] %s %d → %s: %s", status, queueID, to, errMsg)
|
||||
|
||||
// Generate DSN bounce for permanent failures only, never bounce a bounce
|
||||
// (null sender <> = already a DSN).
|
||||
if perm && from != "" && from != "<>" {
|
||||
w.sendDSN(ctx, from, to, errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// sendDSN delivers a Delivery Status Notification (RFC 3464) to the original sender.
|
||||
// Failures here are logged but not re-queued to avoid bounce loops.
|
||||
func (w *QueueWorker) sendDSN(ctx context.Context, originalFrom, failedTo, reason string) {
|
||||
sender := strings.ToLower(originalFrom)
|
||||
|
||||
// Determine if original sender is a local user.
|
||||
user, err := w.deps.DB.GetUserByEmail(ctx, sender)
|
||||
if err != nil {
|
||||
log.Printf("[queue] dsn lookup sender %s: %v", sender, err)
|
||||
return
|
||||
}
|
||||
if user == nil || !user.Enabled {
|
||||
// Sender is remote — attempt external SMTP delivery of DSN.
|
||||
dsnRaw, buildErr := buildDSN(w.deps.Cfg.Hostname, failedTo, reason)
|
||||
if buildErr != nil {
|
||||
log.Printf("[queue] dsn build: %v", buildErr)
|
||||
return
|
||||
}
|
||||
ehlo := w.deps.Cfg.SMTPHostname
|
||||
if ehlo == "" {
|
||||
ehlo = w.deps.Cfg.Hostname
|
||||
}
|
||||
result := delivery.Deliver(ctx, ehlo, "", sender, dsnRaw)
|
||||
if result.SMTPCode != 250 {
|
||||
log.Printf("[queue] dsn delivery to %s failed: %s", sender, result.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Local sender — save DSN directly to INBOX.
|
||||
inbox, err := w.deps.DB.GetMailboxByType(ctx, user.ID, models.MailboxInbox)
|
||||
if err != nil || inbox == nil {
|
||||
log.Printf("[queue] dsn inbox for %s: %v", sender, err)
|
||||
return
|
||||
}
|
||||
|
||||
dsnRaw, err := buildDSN(w.deps.Cfg.Hostname, failedTo, reason)
|
||||
if err != nil {
|
||||
log.Printf("[queue] dsn build: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
hostname := w.deps.Cfg.Hostname
|
||||
if hostname == "" {
|
||||
hostname = "localhost"
|
||||
}
|
||||
|
||||
msg := &storage.IncomingMessage{
|
||||
Raw: dsnRaw,
|
||||
FromEmail: "",
|
||||
Subject: "Delivery Status Notification: failed to deliver to " + failedTo,
|
||||
Date: time.Now().UTC(),
|
||||
MessageID: fmt.Sprintf("<dsn-%d@%s>", time.Now().UnixNano(), hostname),
|
||||
}
|
||||
if _, err := w.deps.Store.SaveIncoming(ctx, user.ID, inbox.ID, msg); err != nil {
|
||||
log.Printf("[queue] dsn save: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// buildDSN constructs a minimal RFC 3464 multipart/report message.
|
||||
func buildDSN(hostname, failedTo, reason string) ([]byte, error) {
|
||||
if hostname == "" {
|
||||
hostname = "localhost"
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
msgID := fmt.Sprintf("<dsn-%d@%s>", now.UnixNano(), hostname)
|
||||
|
||||
var buf bytes.Buffer
|
||||
mw := multipart.NewWriter(&buf)
|
||||
boundary := mw.Boundary()
|
||||
|
||||
// Outer headers.
|
||||
header := fmt.Sprintf(
|
||||
"From: Mail Delivery Subsystem <mailer-daemon@%s>\r\n"+
|
||||
"To: <%s>\r\n"+
|
||||
"Subject: Delivery Status Notification (Failure)\r\n"+
|
||||
"Date: %s\r\n"+
|
||||
"Message-ID: %s\r\n"+
|
||||
"MIME-Version: 1.0\r\n"+
|
||||
"Content-Type: multipart/report; report-type=delivery-status; boundary=%q\r\n"+
|
||||
"\r\n",
|
||||
hostname, failedTo,
|
||||
now.Format("Mon, 02 Jan 2006 15:04:05 -0700"),
|
||||
msgID, boundary,
|
||||
)
|
||||
buf.WriteString(header)
|
||||
|
||||
// Part 1: human-readable explanation.
|
||||
ph := make(textproto.MIMEHeader)
|
||||
ph.Set("Content-Type", "text/plain; charset=utf-8")
|
||||
pw, err := mw.CreatePart(ph)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Fprintf(pw,
|
||||
"Your message could not be delivered to the following recipient:\r\n\r\n"+
|
||||
" Recipient: %s\r\n"+
|
||||
" Reason: %s\r\n\r\n"+
|
||||
"This is a permanent error. The message has not been delivered and will not be retried.\r\n",
|
||||
failedTo, reason)
|
||||
|
||||
// Part 2: machine-readable delivery-status (RFC 3464).
|
||||
sh := make(textproto.MIMEHeader)
|
||||
sh.Set("Content-Type", "message/delivery-status")
|
||||
sw, err := mw.CreatePart(sh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Fprintf(sw,
|
||||
"Reporting-MTA: dns; %s\r\n\r\n"+
|
||||
"Final-Recipient: rfc822; %s\r\n"+
|
||||
"Action: failed\r\n"+
|
||||
"Status: 5.0.0\r\n"+
|
||||
"Diagnostic-Code: smtp; %s\r\n",
|
||||
hostname, failedTo, reason)
|
||||
|
||||
if err := mw.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// nextBackoff returns the back-off duration based on attempt count using
|
||||
|
||||
@@ -198,6 +198,40 @@ func (s *Store) GetRaw(ctx context.Context, userID, messageID int64) ([]byte, er
|
||||
return plain, nil
|
||||
}
|
||||
|
||||
// BodyParts holds decoded body content returned by GetBodyParts.
|
||||
type BodyParts struct {
|
||||
Text string
|
||||
HTML string
|
||||
Attachments []AttachmentMeta
|
||||
}
|
||||
|
||||
// AttachmentMeta describes an attachment without loading its bytes.
|
||||
type AttachmentMeta struct {
|
||||
Filename string
|
||||
ContentType string
|
||||
ContentID string // for inline images
|
||||
Inline bool
|
||||
}
|
||||
|
||||
// GetBodyParts decrypts a message and returns the text/HTML body and attachment list.
|
||||
func (s *Store) GetBodyParts(ctx context.Context, userID, messageID int64) (*BodyParts, error) {
|
||||
raw, err := s.GetRaw(ctx, userID, messageID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
text, html, atts := parseMIME(raw)
|
||||
bp := &BodyParts{Text: text, HTML: html}
|
||||
for _, a := range atts {
|
||||
bp.Attachments = append(bp.Attachments, AttachmentMeta{
|
||||
Filename: a.Filename,
|
||||
ContentType: a.ContentType,
|
||||
ContentID: a.ContentID,
|
||||
Inline: a.Inline,
|
||||
})
|
||||
}
|
||||
return bp, nil
|
||||
}
|
||||
|
||||
// parseMIME walks a raw RFC822 message and extracts body parts and attachments.
|
||||
func parseMIME(raw []byte) (bodyText, bodyHTML string, attachments []parsedAttachment) {
|
||||
m, err := mail.ReadMessage(bytes.NewReader(raw))
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
// Package totp implements RFC 6238 TOTP using stdlib only.
|
||||
package totp
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1" //nolint:gosec — RFC 6238 mandates SHA-1 for TOTP
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// Digits is the number of OTP digits (6).
|
||||
Digits = 6
|
||||
// Period is the TOTP time step in seconds (30).
|
||||
Period = 30
|
||||
// Window is the number of adjacent windows to accept (±1 = 3 windows total).
|
||||
Window = 1
|
||||
// SecretBytes is the raw secret length (20 bytes = 160 bits).
|
||||
SecretBytes = 20
|
||||
// RecoveryCodeCount is the number of single-use backup codes generated.
|
||||
RecoveryCodeCount = 10
|
||||
// RecoveryCodeLen is the length of each backup code (characters, base32 subset).
|
||||
RecoveryCodeLen = 8
|
||||
)
|
||||
|
||||
// GenerateSecret returns a cryptographically random base32-encoded TOTP secret.
|
||||
func GenerateSecret() (string, error) {
|
||||
raw := make([]byte, SecretBytes)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return "", fmt.Errorf("totp: generate secret: %w", err)
|
||||
}
|
||||
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(raw), nil
|
||||
}
|
||||
|
||||
// OTPAuthURI builds the otpauth:// URI for a QR code or manual import.
|
||||
func OTPAuthURI(secret, accountName, issuer string) string {
|
||||
enc := func(s string) string {
|
||||
// RFC 3986 percent-encode for URI path/query
|
||||
var buf strings.Builder
|
||||
for _, b := range []byte(s) {
|
||||
if (b >= 'A' && b <= 'Z') || (b >= 'a' && b <= 'z') ||
|
||||
(b >= '0' && b <= '9') || b == '-' || b == '_' || b == '.' || b == '~' {
|
||||
buf.WriteByte(b)
|
||||
} else {
|
||||
fmt.Fprintf(&buf, "%%%02X", b)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
return fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&algorithm=SHA1&digits=%d&period=%d",
|
||||
enc(issuer), enc(accountName), secret, enc(issuer), Digits, Period)
|
||||
}
|
||||
|
||||
// Verify checks a 6-digit code against the secret for the current time window (±Window steps).
|
||||
// Returns true if valid. secret is base32-encoded (no padding).
|
||||
func Verify(secret, code string) bool {
|
||||
raw, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(secret))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
step := now / Period
|
||||
for w := int64(-Window); w <= int64(Window); w++ {
|
||||
if hotp(raw, step+w) == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hotp computes an HOTP value for a key and counter.
|
||||
func hotp(key []byte, counter int64) string {
|
||||
msg := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(msg, uint64(counter)) //nolint:gosec
|
||||
mac := hmac.New(sha1.New, key)
|
||||
mac.Write(msg)
|
||||
h := mac.Sum(nil)
|
||||
|
||||
// Dynamic truncation (RFC 4226 §5.4)
|
||||
offset := h[len(h)-1] & 0x0f
|
||||
binCode := (uint32(h[offset]&0x7f) << 24) |
|
||||
(uint32(h[offset+1]) << 16) |
|
||||
(uint32(h[offset+2]) << 8) |
|
||||
uint32(h[offset+3])
|
||||
|
||||
otp := binCode % uint32(math.Pow10(Digits))
|
||||
return fmt.Sprintf("%0*d", Digits, otp)
|
||||
}
|
||||
|
||||
// ---- Recovery codes ----
|
||||
|
||||
// GenerateRecoveryCodes returns RecoveryCodeCount random single-use codes.
|
||||
func GenerateRecoveryCodes() ([]string, error) {
|
||||
codes := make([]string, RecoveryCodeCount)
|
||||
const charset = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" // no O/0/I/1 to avoid confusion
|
||||
buf := make([]byte, RecoveryCodeLen)
|
||||
for i := range codes {
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return nil, fmt.Errorf("totp: generate recovery codes: %w", err)
|
||||
}
|
||||
var sb strings.Builder
|
||||
for _, b := range buf {
|
||||
sb.WriteByte(charset[int(b)%len(charset)])
|
||||
}
|
||||
codes[i] = sb.String()
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// EncodeRecoveryCodes marshals a code slice to JSON bytes for storage.
|
||||
func EncodeRecoveryCodes(codes []string) ([]byte, error) {
|
||||
return json.Marshal(codes)
|
||||
}
|
||||
|
||||
// DecodeRecoveryCodes unmarshals JSON bytes back to a code slice.
|
||||
func DecodeRecoveryCodes(data []byte) ([]string, error) {
|
||||
var codes []string
|
||||
if err := json.Unmarshal(data, &codes); err != nil {
|
||||
return nil, fmt.Errorf("totp: decode recovery codes: %w", err)
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// ConsumeRecoveryCode removes a matching code (case-insensitive) and returns the
|
||||
// updated slice. Returns (nil, false) if code not found.
|
||||
func ConsumeRecoveryCode(codes []string, input string) ([]string, bool) {
|
||||
input = strings.ToUpper(strings.TrimSpace(input))
|
||||
for i, c := range codes {
|
||||
if strings.EqualFold(c, input) {
|
||||
updated := make([]string, 0, len(codes)-1)
|
||||
updated = append(updated, codes[:i]...)
|
||||
updated = append(updated, codes[i+1:]...)
|
||||
return updated, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
package webadmin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/totp"
|
||||
)
|
||||
|
||||
const adminPreAuthCookieName = "mailgo_admin_preauth"
|
||||
const adminPreAuthMaxAge = 300 // 5 minutes
|
||||
|
||||
// loginGet renders the admin login form.
|
||||
func (s *Server) loginGet(w http.ResponseWriter, r *http.Request) {
|
||||
// Already logged in → dashboard.
|
||||
if s.currentAdmin(r) != nil {
|
||||
http.Redirect(w, r, "/admin/", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "login", struct {
|
||||
basePage
|
||||
}{newBaseNoUser(flash, errMsg)})
|
||||
}
|
||||
|
||||
// loginPost handles credential submission (step 1 of 2 if MFA enabled).
|
||||
func (s *Server) loginPost(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.TrimSpace(r.FormValue("email"))
|
||||
password := r.FormValue("password")
|
||||
|
||||
clientIP := realIP(r)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Brute-force check.
|
||||
if s.deps.Brute != nil && s.deps.Cfg.BruteMaxTries > 0 {
|
||||
banned, err := s.deps.Brute.IsBanned(ctx, clientIP)
|
||||
if err != nil {
|
||||
log.Printf("[admin] brute check: %v", err)
|
||||
}
|
||||
if banned {
|
||||
redirect(w, r, "/admin/login", "", "Too many failed attempts. Try again later.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if email == "" || password == "" || len(email) > 254 || len(password) > 1024 {
|
||||
s.recordAttempt(ctx, clientIP, email, false)
|
||||
redirect(w, r, "/admin/login", "", "Invalid credentials.")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := s.deps.DB.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[admin] login db: %v", err)
|
||||
redirect(w, r, "/admin/login", "", "Internal error.")
|
||||
return
|
||||
}
|
||||
|
||||
if user == nil || !user.Enabled || !user.Admin {
|
||||
s.recordAttempt(ctx, clientIP, email, false)
|
||||
redirect(w, r, "/admin/login", "", "Invalid credentials.")
|
||||
return
|
||||
}
|
||||
|
||||
if err := appCrypto.CheckPassword(user.PasswordHash, password); err != nil {
|
||||
s.recordAttempt(ctx, clientIP, email, false)
|
||||
redirect(w, r, "/admin/login", "", "Invalid credentials.")
|
||||
return
|
||||
}
|
||||
|
||||
// Password OK. Check MFA.
|
||||
if user.MFAEnabled && len(user.MFASecretEnc) > 0 {
|
||||
s.setAdminPreAuthCookie(w, user.ID)
|
||||
http.Redirect(w, r, "/admin/login/mfa", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
// No MFA — create session.
|
||||
s.recordAttempt(ctx, clientIP, email, true)
|
||||
if _, err := s.deps.Sessions.Create(w, r, user.ID); err != nil {
|
||||
log.Printf("[admin] session create: %v", err)
|
||||
redirect(w, r, "/admin/login", "", "Session error.")
|
||||
return
|
||||
}
|
||||
s.deps.DB.UpdateLastLogin(ctx, user.ID)
|
||||
http.Redirect(w, r, "/admin/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// mfaGet renders the TOTP challenge page for admin login.
|
||||
func (s *Server) mfaGet(w http.ResponseWriter, r *http.Request) {
|
||||
if s.currentAdmin(r) != nil {
|
||||
http.Redirect(w, r, "/admin/", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
if _, ok := s.adminPreAuthUserID(r); !ok {
|
||||
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "mfa", struct{ basePage }{newBaseNoUser(flash, errMsg)})
|
||||
}
|
||||
|
||||
// mfaPost verifies TOTP and completes admin login.
|
||||
func (s *Server) mfaPost(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := s.adminPreAuthUserID(r)
|
||||
if !ok {
|
||||
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := realIP(r)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if s.deps.Brute != nil && s.deps.Cfg.BruteMaxTries > 0 {
|
||||
banned, _ := s.deps.Brute.IsBanned(ctx, clientIP)
|
||||
if banned {
|
||||
clearAdminPreAuth(w)
|
||||
redirect(w, r, "/admin/login", "", "Too many failed attempts. Try again later.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(r.FormValue("code"))
|
||||
if len(code) == 0 || len(code) > 64 {
|
||||
s.recordAttempt(ctx, clientIP, "", false)
|
||||
redirect(w, r, "/admin/login/mfa", "", "Invalid code.")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := s.deps.DB.GetUserByID(ctx, userID)
|
||||
if err != nil || user == nil || !user.Enabled || !user.Admin || !user.MFAEnabled {
|
||||
clearAdminPreAuth(w)
|
||||
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
secretRaw, err := s.deps.Crypt.DecryptForUser(user.ID, "totp", user.MFASecretEnc)
|
||||
if err != nil {
|
||||
log.Printf("[admin] mfa decrypt: %v", err)
|
||||
clearAdminPreAuth(w)
|
||||
redirect(w, r, "/admin/login", "", "MFA error. Please try again.")
|
||||
return
|
||||
}
|
||||
|
||||
var authenticated bool
|
||||
if len(code) == totp.Digits {
|
||||
authenticated = totp.Verify(string(secretRaw), code)
|
||||
} else if len(user.RecoveryCodesEnc) > 0 {
|
||||
codesRaw, cerr := s.deps.Crypt.DecryptForUser(user.ID, "recovery", user.RecoveryCodesEnc)
|
||||
if cerr == nil {
|
||||
codes, cerr := totp.DecodeRecoveryCodes(codesRaw)
|
||||
if cerr == nil {
|
||||
updated, consumed := totp.ConsumeRecoveryCode(codes, code)
|
||||
if consumed {
|
||||
authenticated = true
|
||||
newEnc, encErr := s.deps.Crypt.EncryptForUser(user.ID, "recovery", mustMarshalAdminCodes(updated))
|
||||
if encErr == nil {
|
||||
_ = s.deps.DB.SetRecoveryCodes(ctx, user.ID, newEnc)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !authenticated {
|
||||
s.recordAttempt(ctx, clientIP, user.Email, false)
|
||||
redirect(w, r, "/admin/login/mfa", "", "Invalid code.")
|
||||
return
|
||||
}
|
||||
|
||||
clearAdminPreAuth(w)
|
||||
s.recordAttempt(ctx, clientIP, user.Email, true)
|
||||
if _, err := s.deps.Sessions.Create(w, r, user.ID); err != nil {
|
||||
log.Printf("[admin] session create: %v", err)
|
||||
redirect(w, r, "/admin/login", "", "Session error.")
|
||||
return
|
||||
}
|
||||
s.deps.DB.UpdateLastLogin(ctx, user.ID)
|
||||
http.Redirect(w, r, "/admin/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// logout destroys the session and redirects to login.
|
||||
func (s *Server) logout(w http.ResponseWriter, r *http.Request) {
|
||||
if err := s.deps.Sessions.Destroy(w, r); err != nil {
|
||||
log.Printf("[admin] session destroy: %v", err)
|
||||
}
|
||||
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// ---- Pre-auth cookie helpers ----
|
||||
|
||||
func (s *Server) setAdminPreAuthCookie(w http.ResponseWriter, userID int64) {
|
||||
ts := fmt.Sprintf("%d", time.Now().Unix())
|
||||
uid := fmt.Sprintf("%d", userID)
|
||||
payload := uid + "|" + ts
|
||||
mac := adminPreAuthMAC(s.deps.Cfg.SessionSecret, payload)
|
||||
value := payload + "|" + mac
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: adminPreAuthCookieName,
|
||||
Value: value,
|
||||
Path: "/admin/login",
|
||||
MaxAge: adminPreAuthMaxAge,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) adminPreAuthUserID(r *http.Request) (int64, bool) {
|
||||
c, err := r.Cookie(adminPreAuthCookieName)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
parts := strings.SplitN(c.Value, "|", 3)
|
||||
if len(parts) != 3 {
|
||||
return 0, false
|
||||
}
|
||||
uid, ts, gotMAC := parts[0], parts[1], parts[2]
|
||||
payload := uid + "|" + ts
|
||||
|
||||
wantMAC := adminPreAuthMAC(s.deps.Cfg.SessionSecret, payload)
|
||||
if !hmac.Equal([]byte(gotMAC), []byte(wantMAC)) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
tsi, err := strconv.ParseInt(ts, 10, 64)
|
||||
if err != nil || time.Now().Unix()-tsi > adminPreAuthMaxAge {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(uid, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return id, true
|
||||
}
|
||||
|
||||
func clearAdminPreAuth(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: adminPreAuthCookieName,
|
||||
Value: "",
|
||||
Path: "/admin/login",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
}
|
||||
|
||||
func adminPreAuthMAC(secret []byte, payload string) string {
|
||||
mac := hmac.New(sha256.New, secret)
|
||||
mac.Write([]byte(payload))
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
func mustMarshalAdminCodes(codes []string) []byte {
|
||||
data, err := totp.EncodeRecoveryCodes(codes)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("totp marshal: %v", err))
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// ---- helpers ----
|
||||
|
||||
func (s *Server) recordAttempt(ctx context.Context, ip, email string, success bool) {
|
||||
if s.deps.Brute != nil {
|
||||
s.deps.Brute.RecordAttempt(ctx, ip, email, success)
|
||||
}
|
||||
}
|
||||
|
||||
// newBaseNoUser builds a basePage without requiring a session (used on login page).
|
||||
func newBaseNoUser(flash, errMsg string) basePage {
|
||||
return basePage{Flash: flash, Error: errMsg}
|
||||
}
|
||||
|
||||
// realIP extracts the client IP, honouring X-Forwarded-For when behind a proxy.
|
||||
func realIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
ip := strings.TrimSpace(parts[0])
|
||||
if net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
return host
|
||||
}
|
||||
@@ -0,0 +1,712 @@
|
||||
package webadmin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/dkim"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
// ---- Dashboard ----
|
||||
|
||||
type dashboardData struct {
|
||||
basePage
|
||||
Stats *db.AdminStats
|
||||
}
|
||||
|
||||
func (s *Server) dashboard(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.deps.DB.GetAdminStats(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[admin] stats: %v", err)
|
||||
stats = &db.AdminStats{}
|
||||
}
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "dashboard", dashboardData{
|
||||
basePage: s.newBase(r, flash, errMsg),
|
||||
Stats: stats,
|
||||
})
|
||||
}
|
||||
|
||||
// ---- Domains ----
|
||||
|
||||
type domainsData struct {
|
||||
basePage
|
||||
Domains []*models.Domain
|
||||
}
|
||||
|
||||
func (s *Server) domainsList(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
doms, err := s.deps.DB.ListDomains(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[admin] list domains: %v", err)
|
||||
}
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "domains", domainsData{
|
||||
basePage: s.newBase(r, flash, errMsg),
|
||||
Domains: doms,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) domainsCreate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
name := strings.ToLower(strings.TrimSpace(r.FormValue("name")))
|
||||
selector := strings.TrimSpace(r.FormValue("selector"))
|
||||
algo := r.FormValue("algo")
|
||||
|
||||
if !validDomain(name) {
|
||||
redirect(w, r, "/admin/domains", "", "Invalid domain name.")
|
||||
return
|
||||
}
|
||||
if selector == "" {
|
||||
selector = "mail"
|
||||
}
|
||||
if !validIdentifier(selector) {
|
||||
redirect(w, r, "/admin/domains", "", "Invalid DKIM selector.")
|
||||
return
|
||||
}
|
||||
if algo != "rsa2048" && algo != "ed25519" {
|
||||
algo = "rsa2048"
|
||||
}
|
||||
|
||||
domID, err := s.deps.DB.CreateDomain(ctx, name, selector, algo)
|
||||
if err != nil {
|
||||
log.Printf("[admin] create domain: %v", err)
|
||||
redirect(w, r, "/admin/domains", "", "Failed to create domain.")
|
||||
return
|
||||
}
|
||||
|
||||
// Auto-generate DKIM key pair.
|
||||
if err := s.generateDKIM(ctx, domID, algo, selector); err != nil {
|
||||
log.Printf("[admin] dkim keygen: %v", err)
|
||||
}
|
||||
|
||||
redirect(w, r, fmt.Sprintf("/admin/domains/%d", domID), "Domain created.", "")
|
||||
}
|
||||
|
||||
type domainDetailData struct {
|
||||
basePage
|
||||
Domain *models.Domain
|
||||
Users []*models.User
|
||||
// DNS hint strings
|
||||
DKIMRecord string
|
||||
SPFHint string
|
||||
DMARCHint string
|
||||
}
|
||||
|
||||
func (s *Server) domainDetail(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
id := pathID(r, "id")
|
||||
dom, err := s.deps.DB.GetDomainByID(ctx, id)
|
||||
if err != nil || dom == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
users, _ := s.deps.DB.ListUsers(ctx, id)
|
||||
flash, errMsg := flashFrom(r)
|
||||
|
||||
dkimRec := ""
|
||||
if dom.DKIMPublic != "" {
|
||||
// Strip PEM headers and newlines to get bare base64 for DNS TXT.
|
||||
pub := strings.ReplaceAll(dom.DKIMPublic, "-----BEGIN PUBLIC KEY-----", "")
|
||||
pub = strings.ReplaceAll(pub, "-----END PUBLIC KEY-----", "")
|
||||
pub = strings.ReplaceAll(pub, "\n", "")
|
||||
pub = strings.TrimSpace(pub)
|
||||
dkimRec = fmt.Sprintf(`%s._domainkey.%s IN TXT "v=DKIM1; k=%s; p=%s"`,
|
||||
dom.DKIMSelector, dom.Name, dkimAlgoKey(dom.DKIMAlgo), pub)
|
||||
}
|
||||
|
||||
s.render(w, "domain", domainDetailData{
|
||||
basePage: s.newBase(r, flash, errMsg),
|
||||
Domain: dom,
|
||||
Users: users,
|
||||
DKIMRecord: dkimRec,
|
||||
SPFHint: fmt.Sprintf(`%s IN TXT "v=spf1 a mx ~all"`, dom.Name),
|
||||
DMARCHint: fmt.Sprintf(`_dmarc.%s IN TXT "v=DMARC1; p=quarantine; rua=mailto:postmaster@%s"`, dom.Name, dom.Name),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) domainToggleEnable(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
id := pathID(r, "id")
|
||||
enabled := r.FormValue("enabled") == "1"
|
||||
if err := s.deps.DB.SetDomainEnabled(ctx, id, enabled); err != nil {
|
||||
log.Printf("[admin] domain enable: %v", err)
|
||||
redirect(w, r, fmt.Sprintf("/admin/domains/%d", id), "", "Update failed.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, fmt.Sprintf("/admin/domains/%d", id), "Domain updated.", "")
|
||||
}
|
||||
|
||||
func (s *Server) domainGenDKIM(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
id := pathID(r, "id")
|
||||
dom, err := s.deps.DB.GetDomainByID(ctx, id)
|
||||
if err != nil || dom == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
algo := r.FormValue("algo")
|
||||
if algo != "rsa2048" && algo != "ed25519" {
|
||||
algo = dom.DKIMAlgo
|
||||
}
|
||||
if algo == "" {
|
||||
algo = "rsa2048"
|
||||
}
|
||||
|
||||
if err := s.generateDKIM(ctx, id, algo, dom.DKIMSelector); err != nil {
|
||||
log.Printf("[admin] dkim regen: %v", err)
|
||||
redirect(w, r, fmt.Sprintf("/admin/domains/%d", id), "", "DKIM generation failed.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, fmt.Sprintf("/admin/domains/%d", id), "DKIM key regenerated. Update your DNS TXT record.", "")
|
||||
}
|
||||
|
||||
func (s *Server) domainSetLimits(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
id := pathID(r, "id")
|
||||
maxUsers, _ := strconv.Atoi(r.FormValue("max_users"))
|
||||
maxQuotaMB, _ := strconv.ParseInt(r.FormValue("max_quota_mb"), 10, 64)
|
||||
if maxUsers < 0 {
|
||||
maxUsers = 0
|
||||
}
|
||||
if maxQuotaMB < 0 {
|
||||
maxQuotaMB = 0
|
||||
}
|
||||
|
||||
if err := s.deps.DB.SetDomainLimits(ctx, id, maxUsers, maxQuotaMB*1024*1024); err != nil {
|
||||
log.Printf("[admin] domain limits: %v", err)
|
||||
redirect(w, r, fmt.Sprintf("/admin/domains/%d", id), "", "Update failed.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, fmt.Sprintf("/admin/domains/%d", id), "Limits updated.", "")
|
||||
}
|
||||
|
||||
func (s *Server) domainDelete(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
id := pathID(r, "id")
|
||||
if err := s.deps.DB.DeleteDomain(ctx, id); err != nil {
|
||||
log.Printf("[admin] delete domain: %v", err)
|
||||
redirect(w, r, fmt.Sprintf("/admin/domains/%d", id), "", "Delete failed.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, "/admin/domains", "Domain deleted.", "")
|
||||
}
|
||||
|
||||
// ---- Users ----
|
||||
|
||||
type usersData struct {
|
||||
basePage
|
||||
Users []*db.UserWithDomain
|
||||
Domains []*models.Domain
|
||||
}
|
||||
|
||||
func (s *Server) usersList(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
users, err := s.deps.DB.ListAllUsers(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[admin] list users: %v", err)
|
||||
}
|
||||
doms, _ := s.deps.DB.ListDomains(ctx)
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "users", usersData{
|
||||
basePage: s.newBase(r, flash, errMsg),
|
||||
Users: users,
|
||||
Domains: doms,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) usersCreate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
domainID, _ := strconv.ParseInt(r.FormValue("domain_id"), 10, 64)
|
||||
username := strings.ToLower(strings.TrimSpace(r.FormValue("username")))
|
||||
password := r.FormValue("password")
|
||||
displayName := strings.TrimSpace(r.FormValue("display_name"))
|
||||
quotaMB, _ := strconv.ParseInt(r.FormValue("quota_mb"), 10, 64)
|
||||
domainAdmin := r.FormValue("domain_admin") == "1"
|
||||
|
||||
if domainID <= 0 || !validUsername(username) || len(password) < 8 || len(password) > 1024 {
|
||||
redirect(w, r, "/admin/users", "", "Invalid input. Username must be alphanumeric, password min 8 chars.")
|
||||
return
|
||||
}
|
||||
if quotaMB <= 0 {
|
||||
quotaMB = 1024 // 1 GB default
|
||||
}
|
||||
|
||||
dom, err := s.deps.DB.GetDomainByID(ctx, domainID)
|
||||
if err != nil || dom == nil {
|
||||
redirect(w, r, "/admin/users", "", "Domain not found.")
|
||||
return
|
||||
}
|
||||
|
||||
email := username + "@" + dom.Name
|
||||
exists, err := s.deps.DB.UserExistsByEmail(ctx, email)
|
||||
if err != nil || exists {
|
||||
redirect(w, r, "/admin/users", "", "Email already in use.")
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := appCrypto.HashPassword(password)
|
||||
if err != nil {
|
||||
redirect(w, r, "/admin/users", "", "Password hashing failed.")
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := s.deps.DB.CreateUser(ctx, domainID, username, email, hash, displayName, quotaMB*1024*1024, domainAdmin)
|
||||
if err != nil {
|
||||
log.Printf("[admin] create user: %v", err)
|
||||
redirect(w, r, "/admin/users", "", "Failed to create user.")
|
||||
return
|
||||
}
|
||||
|
||||
// Create default mailboxes, calendar, address book.
|
||||
if err := createDefaultMailboxes(ctx, s.deps.DB, userID); err != nil {
|
||||
log.Printf("[admin] default mailboxes: %v", err)
|
||||
}
|
||||
if _, err := s.deps.DB.EnsureDefaultCalendar(ctx, userID); err != nil {
|
||||
log.Printf("[admin] default calendar: %v", err)
|
||||
}
|
||||
if _, err := s.deps.DB.EnsureDefaultAddressBook(ctx, userID); err != nil {
|
||||
log.Printf("[admin] default address book: %v", err)
|
||||
}
|
||||
|
||||
redirect(w, r, fmt.Sprintf("/admin/users/%d", userID), "User created.", "")
|
||||
}
|
||||
|
||||
type userDetailData struct {
|
||||
basePage
|
||||
U *db.UserWithDomain
|
||||
}
|
||||
|
||||
func (s *Server) userDetail(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
id := pathID(r, "id")
|
||||
users, err := s.deps.DB.ListAllUsers(ctx)
|
||||
if err != nil {
|
||||
http.Error(w, "db error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var found *db.UserWithDomain
|
||||
for _, u := range users {
|
||||
if u.ID == id {
|
||||
found = u
|
||||
break
|
||||
}
|
||||
}
|
||||
if found == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "user", userDetailData{
|
||||
basePage: s.newBase(r, flash, errMsg),
|
||||
U: found,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) userUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
id := pathID(r, "id")
|
||||
enabled := r.FormValue("enabled") == "1"
|
||||
admin := r.FormValue("admin") == "1"
|
||||
domainAdmin := r.FormValue("domain_admin") == "1"
|
||||
quotaMB, _ := strconv.ParseInt(r.FormValue("quota_mb"), 10, 64)
|
||||
displayName := strings.TrimSpace(r.FormValue("display_name"))
|
||||
|
||||
if len(displayName) > 255 {
|
||||
displayName = displayName[:255]
|
||||
}
|
||||
if quotaMB < 0 {
|
||||
quotaMB = 0
|
||||
}
|
||||
|
||||
if err := s.deps.DB.SetUserEnabled(ctx, id, enabled); err != nil {
|
||||
log.Printf("[admin] user enabled: %v", err)
|
||||
}
|
||||
if err := s.deps.DB.SetUserAdmin(ctx, id, admin, domainAdmin); err != nil {
|
||||
log.Printf("[admin] user admin: %v", err)
|
||||
}
|
||||
if err := s.deps.DB.SetUserQuota(ctx, id, quotaMB*1024*1024); err != nil {
|
||||
log.Printf("[admin] user quota: %v", err)
|
||||
}
|
||||
if err := s.deps.DB.SetUserDisplayName(ctx, id, displayName); err != nil {
|
||||
log.Printf("[admin] user display: %v", err)
|
||||
}
|
||||
|
||||
redirect(w, r, fmt.Sprintf("/admin/users/%d", id), "User updated.", "")
|
||||
}
|
||||
|
||||
func (s *Server) userPassword(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
id := pathID(r, "id")
|
||||
password := r.FormValue("password")
|
||||
if len(password) < 8 || len(password) > 1024 {
|
||||
redirect(w, r, fmt.Sprintf("/admin/users/%d", id), "", "Password must be 8-1024 characters.")
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := appCrypto.HashPassword(password)
|
||||
if err != nil {
|
||||
redirect(w, r, fmt.Sprintf("/admin/users/%d", id), "", "Password error.")
|
||||
return
|
||||
}
|
||||
if err := s.deps.DB.SetUserPassword(ctx, id, hash); err != nil {
|
||||
log.Printf("[admin] set password: %v", err)
|
||||
redirect(w, r, fmt.Sprintf("/admin/users/%d", id), "", "Failed to update password.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, fmt.Sprintf("/admin/users/%d", id), "Password updated.", "")
|
||||
}
|
||||
|
||||
func (s *Server) userDelete(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
id := pathID(r, "id")
|
||||
if err := s.deps.DB.DeleteUser(ctx, id); err != nil {
|
||||
log.Printf("[admin] delete user: %v", err)
|
||||
redirect(w, r, fmt.Sprintf("/admin/users/%d", id), "", "Delete failed.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, "/admin/users", "User deleted.", "")
|
||||
}
|
||||
|
||||
// ---- Queue ----
|
||||
|
||||
type queueData struct {
|
||||
basePage
|
||||
Entries []*db.QueueEntry
|
||||
}
|
||||
|
||||
func (s *Server) queueList(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
entries, err := s.deps.DB.ListQueueEntries(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[admin] queue list: %v", err)
|
||||
}
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "queue", queueData{
|
||||
basePage: s.newBase(r, flash, errMsg),
|
||||
Entries: entries,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) queueRetry(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
id := pathID(r, "id")
|
||||
if err := s.deps.DB.RetryQueueEntry(ctx, id); err != nil {
|
||||
log.Printf("[admin] queue retry: %v", err)
|
||||
redirect(w, r, "/admin/queue", "", "Retry failed.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, "/admin/queue", "Entry queued for immediate retry.", "")
|
||||
}
|
||||
|
||||
func (s *Server) queueDelete(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
id := pathID(r, "id")
|
||||
if err := s.deps.DB.DeleteQueueEntry(ctx, id); err != nil {
|
||||
log.Printf("[admin] queue delete: %v", err)
|
||||
redirect(w, r, "/admin/queue", "", "Delete failed.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, "/admin/queue", "Queue entry deleted.", "")
|
||||
}
|
||||
|
||||
// ---- IP Bans ----
|
||||
|
||||
type bansData struct {
|
||||
basePage
|
||||
Bans []*db.IPBan
|
||||
}
|
||||
|
||||
func (s *Server) bansList(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
bans, err := s.deps.DB.ListIPBans(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[admin] list bans: %v", err)
|
||||
}
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "bans", bansData{
|
||||
basePage: s.newBase(r, flash, errMsg),
|
||||
Bans: bans,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) bansAdd(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
ip := strings.TrimSpace(r.FormValue("ip"))
|
||||
reason := strings.TrimSpace(r.FormValue("reason"))
|
||||
hours, _ := strconv.Atoi(r.FormValue("hours"))
|
||||
|
||||
if net.ParseIP(ip) == nil {
|
||||
redirect(w, r, "/admin/bans", "", "Invalid IP address.")
|
||||
return
|
||||
}
|
||||
if len(reason) > 255 {
|
||||
reason = reason[:255]
|
||||
}
|
||||
if hours < 0 {
|
||||
hours = 0
|
||||
}
|
||||
|
||||
if err := s.deps.DB.AddIPBan(ctx, ip, reason, hours); err != nil {
|
||||
log.Printf("[admin] add ban: %v", err)
|
||||
redirect(w, r, "/admin/bans", "", "Failed to add ban.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, "/admin/bans", fmt.Sprintf("IP %s banned.", ip), "")
|
||||
}
|
||||
|
||||
func (s *Server) bansRemove(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
ip := r.PathValue("ip")
|
||||
if net.ParseIP(ip) == nil {
|
||||
redirect(w, r, "/admin/bans", "", "Invalid IP.")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.deps.DB.RemoveIPBan(ctx, ip); err != nil {
|
||||
log.Printf("[admin] remove ban: %v", err)
|
||||
redirect(w, r, "/admin/bans", "", "Remove failed.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, "/admin/bans", fmt.Sprintf("Ban on %s removed.", ip), "")
|
||||
}
|
||||
|
||||
// ---- Security Events ----
|
||||
|
||||
type eventsData struct {
|
||||
basePage
|
||||
Events []*db.SecurityEvent
|
||||
}
|
||||
|
||||
func (s *Server) eventsList(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
limit := 200
|
||||
if q := r.URL.Query().Get("limit"); q != "" {
|
||||
if n, err := strconv.Atoi(q); err == nil && n > 0 && n <= 1000 {
|
||||
limit = n
|
||||
}
|
||||
}
|
||||
|
||||
evs, err := s.deps.DB.ListSecurityEvents(ctx, limit)
|
||||
if err != nil {
|
||||
log.Printf("[admin] events: %v", err)
|
||||
}
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "events", eventsData{
|
||||
basePage: s.newBase(r, flash, errMsg),
|
||||
Events: evs,
|
||||
})
|
||||
}
|
||||
|
||||
// ---- internal helpers ----
|
||||
|
||||
// validateCSRF checks the CSRF token; on failure writes a 403 and returns false.
|
||||
func (s *Server) validateCSRF(w http.ResponseWriter, r *http.Request) bool {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return false
|
||||
}
|
||||
sess, _, _ := s.deps.Sessions.Get(r)
|
||||
if sess == nil {
|
||||
http.Error(w, "unauthenticated", http.StatusForbidden)
|
||||
return false
|
||||
}
|
||||
if !s.checkCSRF(r, sess.TokenHash) {
|
||||
http.Error(w, "CSRF validation failed", http.StatusForbidden)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// pathID extracts a positive int64 from a URL path value. Returns 0 on error.
|
||||
func pathID(r *http.Request, key string) int64 {
|
||||
id, _ := strconv.ParseInt(r.PathValue(key), 10, 64)
|
||||
return id
|
||||
}
|
||||
|
||||
// validDomain accepts simple dot-separated labels (a-z0-9 and hyphens).
|
||||
func validDomain(s string) bool {
|
||||
if len(s) < 3 || len(s) > 253 {
|
||||
return false
|
||||
}
|
||||
for _, label := range strings.Split(s, ".") {
|
||||
if len(label) == 0 || len(label) > 63 {
|
||||
return false
|
||||
}
|
||||
for _, c := range label {
|
||||
if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// validUsername accepts lowercase alphanumeric + dots + hyphens, 1-64 chars.
|
||||
func validUsername(s string) bool {
|
||||
if len(s) < 1 || len(s) > 64 {
|
||||
return false
|
||||
}
|
||||
for _, c := range s {
|
||||
if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '.' || c == '-' || c == '_') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// validIdentifier accepts [a-zA-Z0-9_-], 1-63 chars.
|
||||
func validIdentifier(s string) bool {
|
||||
if len(s) < 1 || len(s) > 63 {
|
||||
return false
|
||||
}
|
||||
for _, c := range s {
|
||||
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// dkimAlgoKey converts our algo string to the DNS TXT k= value.
|
||||
func dkimAlgoKey(algo string) string {
|
||||
if algo == "ed25519" {
|
||||
return "ed25519"
|
||||
}
|
||||
return "rsa"
|
||||
}
|
||||
|
||||
// generateDKIM creates a new DKIM key pair and persists it encrypted.
|
||||
func (s *Server) generateDKIM(ctx context.Context, domainID int64, algo, selector string) error {
|
||||
privPEM, pubPEM, err := dkim.GenerateKeyPair(algo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("keygen: %w", err)
|
||||
}
|
||||
|
||||
privEnc, err := s.deps.Crypt.EncryptGlobal("dkim", []byte(privPEM))
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt dkim key: %w", err)
|
||||
}
|
||||
|
||||
if err := s.deps.DB.SaveDKIMKeys(ctx, domainID, privEnc, pubPEM); err != nil {
|
||||
return fmt.Errorf("save dkim keys: %w", err)
|
||||
}
|
||||
|
||||
// Update algo + selector in case they changed.
|
||||
_, err = s.deps.DB.SQL().ExecContext(ctx,
|
||||
"UPDATE domains SET dkim_algo=?, dkim_selector=? WHERE id=?", algo, selector, domainID)
|
||||
return err
|
||||
}
|
||||
|
||||
// createDefaultMailboxes creates INBOX, Sent, Drafts, Trash, Spam, Archive for a new user.
|
||||
func createDefaultMailboxes(ctx context.Context, database *db.DB, userID int64) error {
|
||||
return database.CreateDefaultMailboxes(ctx, userID)
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
// Package webadmin provides the HTTP admin panel (default: 127.0.0.1:8081).
|
||||
// All routes require an authenticated admin session except /admin/login.
|
||||
package webadmin
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io/fs"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/auth"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/config"
|
||||
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
// Deps groups all dependencies for the admin panel.
|
||||
type Deps struct {
|
||||
DB *db.DB
|
||||
Crypt *appCrypto.Crypto
|
||||
Sessions *auth.SessionStore
|
||||
Brute *auth.BruteGuard
|
||||
Cfg *config.Config
|
||||
FS fs.FS // embed.FS sub-rooted at web/admin
|
||||
}
|
||||
|
||||
// Server handles all admin HTTP routes.
|
||||
type Server struct {
|
||||
deps *Deps
|
||||
mux *http.ServeMux
|
||||
}
|
||||
|
||||
// New creates an admin Server and registers all routes.
|
||||
func New(deps *Deps) *Server {
|
||||
s := &Server{deps: deps, mux: http.NewServeMux()}
|
||||
s.setupRoutes()
|
||||
return s
|
||||
}
|
||||
|
||||
// Handler returns the HTTP handler for the admin server.
|
||||
// Requests are logged; all routes except /admin/login require admin auth.
|
||||
func (s *Server) Handler() http.Handler {
|
||||
return logMiddleware(s.mux)
|
||||
}
|
||||
|
||||
// ---- routing ----
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
m := s.mux
|
||||
|
||||
// Public
|
||||
m.HandleFunc("GET /admin/login", s.loginGet)
|
||||
m.HandleFunc("POST /admin/login", s.loginPost)
|
||||
m.HandleFunc("GET /admin/login/mfa", s.mfaGet)
|
||||
m.HandleFunc("POST /admin/login/mfa", s.mfaPost)
|
||||
m.HandleFunc("GET /admin/logout", s.logout)
|
||||
|
||||
// Protected
|
||||
m.HandleFunc("GET /admin/{$}", s.require(s.dashboard))
|
||||
|
||||
m.HandleFunc("GET /admin/domains", s.require(s.domainsList))
|
||||
m.HandleFunc("POST /admin/domains", s.require(s.domainsCreate))
|
||||
m.HandleFunc("GET /admin/domains/{id}", s.require(s.domainDetail))
|
||||
m.HandleFunc("POST /admin/domains/{id}/enable", s.require(s.domainToggleEnable))
|
||||
m.HandleFunc("POST /admin/domains/{id}/dkim", s.require(s.domainGenDKIM))
|
||||
m.HandleFunc("POST /admin/domains/{id}/limits", s.require(s.domainSetLimits))
|
||||
m.HandleFunc("POST /admin/domains/{id}/delete", s.require(s.domainDelete))
|
||||
|
||||
m.HandleFunc("GET /admin/users", s.require(s.usersList))
|
||||
m.HandleFunc("POST /admin/users", s.require(s.usersCreate))
|
||||
m.HandleFunc("GET /admin/users/{id}", s.require(s.userDetail))
|
||||
m.HandleFunc("POST /admin/users/{id}/update", s.require(s.userUpdate))
|
||||
m.HandleFunc("POST /admin/users/{id}/password", s.require(s.userPassword))
|
||||
m.HandleFunc("POST /admin/users/{id}/delete", s.require(s.userDelete))
|
||||
|
||||
m.HandleFunc("GET /admin/queue", s.require(s.queueList))
|
||||
m.HandleFunc("POST /admin/queue/{id}/retry", s.require(s.queueRetry))
|
||||
m.HandleFunc("POST /admin/queue/{id}/delete", s.require(s.queueDelete))
|
||||
|
||||
m.HandleFunc("GET /admin/bans", s.require(s.bansList))
|
||||
m.HandleFunc("POST /admin/bans", s.require(s.bansAdd))
|
||||
m.HandleFunc("POST /admin/bans/{ip}/remove", s.require(s.bansRemove))
|
||||
|
||||
m.HandleFunc("GET /admin/events", s.require(s.eventsList))
|
||||
|
||||
// Static assets
|
||||
static, _ := fs.Sub(s.deps.FS, "static")
|
||||
m.Handle("GET /admin/static/", http.StripPrefix("/admin/static/", http.FileServer(http.FS(static))))
|
||||
}
|
||||
|
||||
// ---- middleware ----
|
||||
|
||||
// require wraps a handler with admin session enforcement.
|
||||
func (s *Server) require(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
user := s.currentAdmin(r)
|
||||
if user == nil {
|
||||
http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// currentAdmin returns the logged-in admin user, or nil if unauthenticated / not admin.
|
||||
func (s *Server) currentAdmin(r *http.Request) *models.User {
|
||||
_, user, err := s.deps.Sessions.Get(r)
|
||||
if err != nil || user == nil || !user.Admin || !user.Enabled {
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
// logMiddleware logs method + path + duration for every request.
|
||||
func logMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
next.ServeHTTP(w, r)
|
||||
log.Printf("[admin] %s %s %s", r.Method, r.URL.Path, time.Since(start))
|
||||
})
|
||||
}
|
||||
|
||||
// ---- CSRF ----
|
||||
|
||||
// csrfToken returns an HMAC-SHA256 token valid for the current clock-hour.
|
||||
// Bound to sessionHash so it cannot be forged without the session secret.
|
||||
func (s *Server) csrfToken(sessionHash string) string {
|
||||
h := hmac.New(sha256.New, s.deps.Cfg.SessionSecret)
|
||||
h.Write([]byte(sessionHash))
|
||||
h.Write([]byte(time.Now().UTC().Format("2006-01-02-15")))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// checkCSRF validates the CSRF token submitted via form field "_csrf".
|
||||
// Also accepts previous-hour token to handle hour-boundary edge cases.
|
||||
func (s *Server) checkCSRF(r *http.Request, sessionHash string) bool {
|
||||
got := r.FormValue("_csrf")
|
||||
if got == "" {
|
||||
return false
|
||||
}
|
||||
cur := s.csrfToken(sessionHash)
|
||||
if hmac.Equal([]byte(got), []byte(cur)) {
|
||||
return true
|
||||
}
|
||||
// Allow previous-hour token (grace window).
|
||||
prev := s.csrfTokenAt(sessionHash, time.Now().UTC().Add(-time.Hour))
|
||||
return hmac.Equal([]byte(got), []byte(prev))
|
||||
}
|
||||
|
||||
func (s *Server) csrfTokenAt(sessionHash string, t time.Time) string {
|
||||
h := hmac.New(sha256.New, s.deps.Cfg.SessionSecret)
|
||||
h.Write([]byte(sessionHash))
|
||||
h.Write([]byte(t.Format("2006-01-02-15")))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// ---- Template rendering ----
|
||||
|
||||
// basePage holds fields available in every template.
|
||||
type basePage struct {
|
||||
Flash string
|
||||
Error string
|
||||
CSRF string
|
||||
Admin *models.User
|
||||
}
|
||||
|
||||
func (s *Server) newBase(r *http.Request, flash, errMsg string) basePage {
|
||||
sessObj, user, _ := s.deps.Sessions.Get(r)
|
||||
var csrf string
|
||||
if sessObj != nil {
|
||||
csrf = s.csrfToken(sessObj.TokenHash)
|
||||
}
|
||||
return basePage{Flash: flash, Error: errMsg, CSRF: csrf, Admin: user}
|
||||
}
|
||||
|
||||
// render parses base.html + page.html from embed.FS and executes "base" template.
|
||||
func (s *Server) render(w http.ResponseWriter, page string, data any) {
|
||||
t, err := template.New("").Funcs(tmplFuncs).ParseFS(
|
||||
s.deps.FS,
|
||||
"templates/base.html",
|
||||
"templates/"+page+".html",
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("[admin] template parse %s: %v", page, err)
|
||||
http.Error(w, "template error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
if err := t.ExecuteTemplate(w, "base", data); err != nil {
|
||||
log.Printf("[admin] template exec %s: %v", page, err)
|
||||
}
|
||||
}
|
||||
|
||||
// redirect sends a 303 to a path with optional flash/error query params.
|
||||
func redirect(w http.ResponseWriter, r *http.Request, path, flash, errMsg string) {
|
||||
target := path
|
||||
if flash != "" {
|
||||
target += "?flash=" + urlEncode(flash)
|
||||
} else if errMsg != "" {
|
||||
target += "?error=" + urlEncode(errMsg)
|
||||
}
|
||||
http.Redirect(w, r, target, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// flashFrom extracts flash/error from query params (used after redirect).
|
||||
func flashFrom(r *http.Request) (flash, errMsg string) {
|
||||
return r.URL.Query().Get("flash"), r.URL.Query().Get("error")
|
||||
}
|
||||
|
||||
// urlEncode does basic percent-encoding for query values.
|
||||
func urlEncode(s string) string {
|
||||
return fmt.Sprintf("%s", template.URLQueryEscaper(s))
|
||||
}
|
||||
|
||||
// tmplFuncs are custom template functions available in all admin templates.
|
||||
var tmplFuncs = template.FuncMap{
|
||||
"humanBytes": humanBytes,
|
||||
"shortTime": func(t time.Time) string { return t.Format("2006-01-02 15:04") },
|
||||
"isZero": func(t time.Time) bool { return t.IsZero() },
|
||||
"mb": func(b int64) int64 { return b / 1024 / 1024 },
|
||||
}
|
||||
|
||||
func humanBytes(b int64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
@@ -0,0 +1,315 @@
|
||||
package webclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/totp"
|
||||
)
|
||||
|
||||
const preAuthCookieName = "mailgo_preauth"
|
||||
const preAuthMaxAge = 300 // 5 minutes
|
||||
|
||||
// ---- Login (step 1: password) ----
|
||||
|
||||
func (s *Server) loginGet(w http.ResponseWriter, r *http.Request) {
|
||||
if s.currentUser(r) != nil {
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "login", struct{ basePage }{
|
||||
basePage: basePage{Flash: flash, Error: errMsg},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) loginPost(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(r.FormValue("email")))
|
||||
password := r.FormValue("password")
|
||||
clientIP := realIP(r)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Brute-force check.
|
||||
if s.deps.Brute != nil && s.deps.Cfg.BruteMaxTries > 0 {
|
||||
banned, err := s.deps.Brute.IsBanned(ctx, clientIP)
|
||||
if err != nil {
|
||||
log.Printf("[webmail] brute check: %v", err)
|
||||
}
|
||||
if banned {
|
||||
redirect(w, r, "/login", "", "Too many failed attempts. Try again later.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if email == "" || password == "" || len(email) > 254 || len(password) > 1024 {
|
||||
s.recordAttempt(ctx, clientIP, email, false)
|
||||
redirect(w, r, "/login", "", "Invalid credentials.")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := s.deps.DB.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
log.Printf("[webmail] login db: %v", err)
|
||||
redirect(w, r, "/login", "", "Internal error.")
|
||||
return
|
||||
}
|
||||
|
||||
if user == nil || !user.Enabled {
|
||||
s.recordAttempt(ctx, clientIP, email, false)
|
||||
redirect(w, r, "/login", "", "Invalid credentials.")
|
||||
return
|
||||
}
|
||||
|
||||
if err := appCrypto.CheckPassword(user.PasswordHash, password); err != nil {
|
||||
s.recordAttempt(ctx, clientIP, email, false)
|
||||
redirect(w, r, "/login", "", "Invalid credentials.")
|
||||
return
|
||||
}
|
||||
|
||||
// Password OK. Check MFA.
|
||||
if user.MFAEnabled && len(user.MFASecretEnc) > 0 {
|
||||
// Issue pre-auth cookie and redirect to TOTP step.
|
||||
s.setPreAuthCookie(w, user.ID)
|
||||
http.Redirect(w, r, "/login/mfa", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
// No MFA — create session directly.
|
||||
s.recordAttempt(ctx, clientIP, email, true)
|
||||
if _, err := s.deps.Sessions.Create(w, r, user.ID); err != nil {
|
||||
log.Printf("[webmail] session create: %v", err)
|
||||
redirect(w, r, "/login", "", "Session error.")
|
||||
return
|
||||
}
|
||||
s.deps.DB.UpdateLastLogin(ctx, user.ID)
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// ---- Login (step 2: TOTP) ----
|
||||
|
||||
func (s *Server) mfaGet(w http.ResponseWriter, r *http.Request) {
|
||||
if s.currentUser(r) != nil {
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
if _, ok := s.preAuthUserID(r); !ok {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "mfa", struct{ basePage }{
|
||||
basePage: basePage{Flash: flash, Error: errMsg},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) mfaPost(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := s.preAuthUserID(r)
|
||||
if !ok {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := realIP(r)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Brute-force check (TOTP codes are brute-forceable too).
|
||||
if s.deps.Brute != nil && s.deps.Cfg.BruteMaxTries > 0 {
|
||||
banned, _ := s.deps.Brute.IsBanned(ctx, clientIP)
|
||||
if banned {
|
||||
clearPreAuth(w)
|
||||
redirect(w, r, "/login", "", "Too many failed attempts. Try again later.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(r.FormValue("code"))
|
||||
if len(code) == 0 || len(code) > 64 {
|
||||
s.recordAttempt(ctx, clientIP, "", false)
|
||||
redirect(w, r, "/login/mfa", "", "Invalid code.")
|
||||
return
|
||||
}
|
||||
|
||||
user, err := s.deps.DB.GetUserByID(ctx, userID)
|
||||
if err != nil || user == nil || !user.Enabled || !user.MFAEnabled {
|
||||
clearPreAuth(w)
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
// Decrypt the TOTP secret.
|
||||
secretRaw, err := s.deps.Crypt.DecryptForUser(user.ID, "totp", user.MFASecretEnc)
|
||||
if err != nil {
|
||||
log.Printf("[webmail] mfa decrypt: %v", err)
|
||||
clearPreAuth(w)
|
||||
redirect(w, r, "/login", "", "MFA error. Please try again.")
|
||||
return
|
||||
}
|
||||
|
||||
var authenticated bool
|
||||
|
||||
if len(code) == totp.Digits {
|
||||
// TOTP path.
|
||||
authenticated = totp.Verify(string(secretRaw), code)
|
||||
} else {
|
||||
// Recovery code path.
|
||||
if len(user.RecoveryCodesEnc) > 0 {
|
||||
codesRaw, cerr := s.deps.Crypt.DecryptForUser(user.ID, "recovery", user.RecoveryCodesEnc)
|
||||
if cerr == nil {
|
||||
codes, cerr := totp.DecodeRecoveryCodes(codesRaw)
|
||||
if cerr == nil {
|
||||
updated, consumed := totp.ConsumeRecoveryCode(codes, code)
|
||||
if consumed {
|
||||
authenticated = true
|
||||
newEnc, encErr := s.deps.Crypt.EncryptForUser(user.ID, "recovery", mustMarshalCodes(updated))
|
||||
if encErr == nil {
|
||||
_ = s.deps.DB.SetRecoveryCodes(ctx, user.ID, newEnc)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !authenticated {
|
||||
s.recordAttempt(ctx, clientIP, user.Email, false)
|
||||
redirect(w, r, "/login/mfa", "", "Invalid code.")
|
||||
return
|
||||
}
|
||||
|
||||
clearPreAuth(w)
|
||||
s.recordAttempt(ctx, clientIP, user.Email, true)
|
||||
if _, err := s.deps.Sessions.Create(w, r, user.ID); err != nil {
|
||||
log.Printf("[webmail] session create: %v", err)
|
||||
redirect(w, r, "/login", "", "Session error.")
|
||||
return
|
||||
}
|
||||
s.deps.DB.UpdateLastLogin(ctx, user.ID)
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// ---- Logout ----
|
||||
|
||||
func (s *Server) logout(w http.ResponseWriter, r *http.Request) {
|
||||
if err := s.deps.Sessions.Destroy(w, r); err != nil {
|
||||
log.Printf("[webmail] session destroy: %v", err)
|
||||
}
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// ---- Pre-auth cookie ----
|
||||
|
||||
func (s *Server) setPreAuthCookie(w http.ResponseWriter, userID int64) {
|
||||
ts := fmt.Sprintf("%d", time.Now().Unix())
|
||||
uid := fmt.Sprintf("%d", userID)
|
||||
payload := uid + "|" + ts
|
||||
mac := preAuthMAC(s.deps.Cfg.SessionSecret, payload)
|
||||
value := payload + "|" + mac
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: preAuthCookieName,
|
||||
Value: value,
|
||||
Path: "/login",
|
||||
MaxAge: preAuthMaxAge,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) preAuthUserID(r *http.Request) (int64, bool) {
|
||||
c, err := r.Cookie(preAuthCookieName)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
parts := strings.SplitN(c.Value, "|", 3)
|
||||
if len(parts) != 3 {
|
||||
return 0, false
|
||||
}
|
||||
uid, ts, gotMAC := parts[0], parts[1], parts[2]
|
||||
payload := uid + "|" + ts
|
||||
|
||||
wantMAC := preAuthMAC(s.deps.Cfg.SessionSecret, payload)
|
||||
if !hmac.Equal([]byte(gotMAC), []byte(wantMAC)) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
tsi, err := strconv.ParseInt(ts, 10, 64)
|
||||
if err != nil || time.Now().Unix()-tsi > preAuthMaxAge {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
id, err := strconv.ParseInt(uid, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return id, true
|
||||
}
|
||||
|
||||
func clearPreAuth(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: preAuthCookieName,
|
||||
Value: "",
|
||||
Path: "/login",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
}
|
||||
|
||||
func preAuthMAC(secret []byte, payload string) string {
|
||||
mac := hmac.New(sha256.New, secret)
|
||||
mac.Write([]byte(payload))
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
func mustMarshalCodes(codes []string) []byte {
|
||||
data, err := totp.EncodeRecoveryCodes(codes)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("totp marshal: %v", err))
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// ---- Shared helpers ----
|
||||
|
||||
func (s *Server) recordAttempt(ctx context.Context, ip, email string, success bool) {
|
||||
if s.deps.Brute != nil {
|
||||
s.deps.Brute.RecordAttempt(ctx, ip, email, success)
|
||||
}
|
||||
}
|
||||
|
||||
func realIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
ip := strings.TrimSpace(parts[0])
|
||||
if net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
return host
|
||||
}
|
||||
@@ -0,0 +1,193 @@
|
||||
package webclient
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"mime/quotedprintable"
|
||||
"net/mail"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/storage"
|
||||
)
|
||||
|
||||
// ComposeParams holds parsed compose form fields.
|
||||
type ComposeParams struct {
|
||||
From string // "Display Name <email@host>"
|
||||
FromEmail string // bare email
|
||||
To []string // each a valid RFC 5322 address
|
||||
CC []string
|
||||
BCC []string
|
||||
Subject string
|
||||
BodyText string
|
||||
InReplyTo string
|
||||
References string
|
||||
MessageID string // auto-generated if empty
|
||||
}
|
||||
|
||||
// BuildRFC5322 creates a raw RFC 5322 message.
|
||||
func BuildRFC5322(p *ComposeParams) ([]byte, error) {
|
||||
if p.MessageID == "" {
|
||||
randHex, err := appCrypto.RandomHex(16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("message-id random: %w", err)
|
||||
}
|
||||
fromDomain := "localhost"
|
||||
if idx := strings.LastIndex(p.FromEmail, "@"); idx >= 0 {
|
||||
fromDomain = p.FromEmail[idx+1:]
|
||||
}
|
||||
p.MessageID = "<" + randHex + "@" + fromDomain + ">"
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
writeHeader := func(k, v string) {
|
||||
if v != "" {
|
||||
buf.WriteString(k + ": " + v + "\r\n")
|
||||
}
|
||||
}
|
||||
|
||||
writeHeader("From", p.From)
|
||||
writeHeader("To", strings.Join(p.To, ", "))
|
||||
if len(p.CC) > 0 {
|
||||
writeHeader("Cc", strings.Join(p.CC, ", "))
|
||||
}
|
||||
writeHeader("Subject", mime.QEncoding.Encode("utf-8", p.Subject))
|
||||
writeHeader("Date", time.Now().UTC().Format(time.RFC1123Z))
|
||||
writeHeader("Message-Id", p.MessageID)
|
||||
writeHeader("MIME-Version", "1.0")
|
||||
if p.InReplyTo != "" {
|
||||
writeHeader("In-Reply-To", p.InReplyTo)
|
||||
}
|
||||
if p.References != "" {
|
||||
writeHeader("References", p.References)
|
||||
}
|
||||
|
||||
// Write body as quoted-printable text/plain.
|
||||
mw := multipart.NewWriter(&buf)
|
||||
buf.WriteString("Content-Type: multipart/alternative; boundary=\"" + mw.Boundary() + "\"\r\n")
|
||||
buf.WriteString("\r\n")
|
||||
|
||||
// text/plain part
|
||||
th := make(textproto.MIMEHeader)
|
||||
th.Set("Content-Type", "text/plain; charset=utf-8")
|
||||
th.Set("Content-Transfer-Encoding", "quoted-printable")
|
||||
pw, err := mw.CreatePart(th)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create text part: %w", err)
|
||||
}
|
||||
qw := quotedprintable.NewWriter(pw)
|
||||
if _, err := qw.Write([]byte(p.BodyText)); err != nil {
|
||||
return nil, fmt.Errorf("write body: %w", err)
|
||||
}
|
||||
if err := qw.Close(); err != nil {
|
||||
return nil, fmt.Errorf("close qp: %w", err)
|
||||
}
|
||||
|
||||
if err := mw.Close(); err != nil {
|
||||
return nil, fmt.Errorf("close multipart: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// parseAddressList parses a comma-separated address string into valid email addresses.
|
||||
// Returns only the bare email addresses (no display names in returned slice).
|
||||
func parseAddressList(raw string) ([]string, error) {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
addrs, err := mail.ParseAddressList(raw)
|
||||
if err != nil {
|
||||
// Try as a single address.
|
||||
addr, err2 := mail.ParseAddress(raw)
|
||||
if err2 != nil {
|
||||
return nil, fmt.Errorf("invalid address %q: %w", raw, err)
|
||||
}
|
||||
return []string{addr.Address}, nil
|
||||
}
|
||||
out := make([]string, 0, len(addrs))
|
||||
for _, a := range addrs {
|
||||
out = append(out, a.Address)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// addressListRFC formats a list of bare emails as RFC 5322 addresses.
|
||||
func addressListRFC(emails []string) []string {
|
||||
out := make([]string, len(emails))
|
||||
for i, e := range emails {
|
||||
out[i] = e
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// deliverLocally saves a message to a local recipient's INBOX.
|
||||
func (s *Server) deliverLocally(ctx context.Context, recipientEmail string, raw []byte, msg *storage.IncomingMessage) error {
|
||||
user, err := s.deps.DB.ResolveEmail(ctx, recipientEmail)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve %s: %w", recipientEmail, err)
|
||||
}
|
||||
if user == nil || !user.Enabled {
|
||||
return fmt.Errorf("recipient %s not found or disabled", recipientEmail)
|
||||
}
|
||||
|
||||
inbox, err := s.deps.DB.GetMailboxByType(ctx, user.ID, models.MailboxInbox)
|
||||
if err != nil || inbox == nil {
|
||||
return fmt.Errorf("inbox not found for %s", recipientEmail)
|
||||
}
|
||||
|
||||
if _, err := s.deps.Store.SaveIncoming(ctx, user.ID, inbox.ID, msg); err != nil {
|
||||
return fmt.Errorf("save incoming: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveSentCopy saves a copy of a sent message to the sender's Sent folder.
|
||||
func (s *Server) saveSentCopy(ctx context.Context, senderID int64, raw []byte, msg *storage.IncomingMessage) error {
|
||||
sentBox, err := s.deps.DB.GetMailboxByType(ctx, senderID, models.MailboxSent)
|
||||
if err != nil || sentBox == nil {
|
||||
return fmt.Errorf("sent mailbox not found")
|
||||
}
|
||||
_, err = s.deps.Store.SaveIncoming(ctx, senderID, sentBox.ID, msg)
|
||||
return err
|
||||
}
|
||||
|
||||
// enqueueForDelivery adds a message to the delivery queue for a remote recipient.
|
||||
func (s *Server) enqueueForDelivery(ctx context.Context, fromEmail, toEmail string, raw []byte, msgID string) error {
|
||||
// Determine domain for queue domain_id (best effort, nil if not local).
|
||||
fromDomain := ""
|
||||
if idx := strings.LastIndex(fromEmail, "@"); idx >= 0 {
|
||||
fromDomain = fromEmail[idx+1:]
|
||||
}
|
||||
var domainID *int64
|
||||
if dom, err := s.deps.DB.GetDomain(ctx, fromDomain); err == nil && dom != nil {
|
||||
domainID = &dom.ID
|
||||
}
|
||||
|
||||
maxAge := s.deps.Cfg.QueueMaxAgeHours
|
||||
if maxAge <= 0 {
|
||||
maxAge = 72
|
||||
}
|
||||
|
||||
key, err := s.deps.Crypt.DeriveKeyGlobal("queue")
|
||||
if err != nil {
|
||||
return fmt.Errorf("queue key: %w", err)
|
||||
}
|
||||
rawEnc, err := appCrypto.Encrypt(key, raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt queue: %w", err)
|
||||
}
|
||||
|
||||
domID := int64(0)
|
||||
if domainID != nil {
|
||||
domID = *domainID
|
||||
}
|
||||
_, err = s.deps.DB.EnqueueMessage(ctx, domID, fromEmail, toEmail, msgID, rawEnc, maxAge)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,904 @@
|
||||
package webclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/storage"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/totp"
|
||||
)
|
||||
|
||||
const messagesPerPage = 50
|
||||
|
||||
// ---- Root redirect ----
|
||||
|
||||
func (s *Server) rootRedirect(w http.ResponseWriter, r *http.Request) {
|
||||
user := s.currentUser(r)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
inbox, err := s.deps.DB.GetMailboxByType(ctx, user.ID, models.MailboxInbox)
|
||||
if err != nil || inbox == nil {
|
||||
redirect(w, r, "/settings", "", "No inbox found. Contact your admin.")
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, fmt.Sprintf("/mail/%d", inbox.ID), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// ---- Mailbox view (message list) ----
|
||||
|
||||
type mailboxPage struct {
|
||||
basePage
|
||||
CurrentBox *models.Mailbox
|
||||
Messages []*db.IMAPMessage
|
||||
Query string
|
||||
PrevPage uint32 // UID of last message on prev page (0 = none)
|
||||
NextPage uint32 // UID of first message on next page (0 = none)
|
||||
TotalCount int
|
||||
}
|
||||
|
||||
func (s *Server) mailboxView(w http.ResponseWriter, r *http.Request) {
|
||||
user := s.currentUser(r)
|
||||
boxID := pathID(r, "boxid")
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
|
||||
if err != nil || box == nil || box.UserID != user.ID {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
msgs, err := s.deps.DB.ListIMAPMessages(ctx, boxID)
|
||||
if err != nil {
|
||||
log.Printf("[webmail] list messages: %v", err)
|
||||
msgs = nil
|
||||
}
|
||||
|
||||
// Reverse for newest-first.
|
||||
for i, j := 0, len(msgs)-1; i < j; i, j = i+1, j-1 {
|
||||
msgs[i], msgs[j] = msgs[j], msgs[i]
|
||||
}
|
||||
|
||||
// Apply search filter.
|
||||
query := strings.TrimSpace(r.URL.Query().Get("q"))
|
||||
if query != "" {
|
||||
ql := strings.ToLower(query)
|
||||
filtered := msgs[:0]
|
||||
for _, m := range msgs {
|
||||
if strings.Contains(strings.ToLower(m.Subject), ql) ||
|
||||
strings.Contains(strings.ToLower(m.FromEmail), ql) ||
|
||||
strings.Contains(strings.ToLower(m.FromName), ql) {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
msgs = filtered
|
||||
}
|
||||
|
||||
total := len(msgs)
|
||||
|
||||
// Pagination: "before" UID (load messages with UID < before).
|
||||
var prevPage, nextPage uint32
|
||||
beforeUID, _ := strconv.ParseUint(r.URL.Query().Get("before"), 10, 32)
|
||||
if beforeUID > 0 {
|
||||
// Filter messages with UID < beforeUID.
|
||||
cutEnd := len(msgs)
|
||||
for i, m := range msgs {
|
||||
if uint64(m.UID) < beforeUID {
|
||||
cutEnd = i + messagesPerPage
|
||||
if cutEnd > len(msgs) {
|
||||
cutEnd = len(msgs)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
// Find start.
|
||||
cutStart := 0
|
||||
for i, m := range msgs {
|
||||
if uint64(m.UID) < beforeUID {
|
||||
cutStart = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if cutStart > 0 {
|
||||
prevPage = msgs[cutStart-1].UID
|
||||
}
|
||||
msgs = msgs[cutStart:cutEnd]
|
||||
} else {
|
||||
// First page: newest messagesPerPage.
|
||||
if len(msgs) > messagesPerPage {
|
||||
nextPage = msgs[messagesPerPage].UID
|
||||
msgs = msgs[:messagesPerPage]
|
||||
}
|
||||
}
|
||||
|
||||
flash, errMsg := flashFrom(r)
|
||||
base := s.newBase(r, flash, errMsg)
|
||||
base.CurrentBoxID = boxID
|
||||
s.render(w, "mail", mailboxPage{
|
||||
basePage: base,
|
||||
CurrentBox: box,
|
||||
Messages: msgs,
|
||||
Query: query,
|
||||
PrevPage: prevPage,
|
||||
NextPage: nextPage,
|
||||
TotalCount: total,
|
||||
})
|
||||
}
|
||||
|
||||
// ---- Message view ----
|
||||
|
||||
type messagePage struct {
|
||||
basePage
|
||||
CurrentBox *models.Mailbox
|
||||
Message *db.IMAPMessage
|
||||
Body *storage.BodyParts
|
||||
PrevUID uint32
|
||||
NextUID uint32
|
||||
}
|
||||
|
||||
func (s *Server) messageView(w http.ResponseWriter, r *http.Request) {
|
||||
user := s.currentUser(r)
|
||||
boxID := pathID(r, "boxid")
|
||||
uid64, err := strconv.ParseUint(r.PathValue("uid"), 10, 32)
|
||||
if err != nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
uid := uint32(uid64)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
|
||||
if err != nil || box == nil || box.UserID != user.ID {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := s.deps.DB.GetIMAPMessageByUID(ctx, boxID, uid)
|
||||
if err != nil || msg == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Auto-mark as read.
|
||||
if !msg.IsRead {
|
||||
if err := s.deps.DB.SetMessageFlags(ctx, msg.ID, true, msg.IsStarred, msg.IsDraft, msg.Flags); err != nil {
|
||||
log.Printf("[webmail] mark read: %v", err)
|
||||
} else {
|
||||
msg.IsRead = true
|
||||
}
|
||||
}
|
||||
|
||||
// Decrypt body parts.
|
||||
body, err := s.deps.Store.GetBodyParts(ctx, user.ID, msg.ID)
|
||||
if err != nil {
|
||||
log.Printf("[webmail] get body: %v", err)
|
||||
body = &storage.BodyParts{Text: "[Error loading message body]"}
|
||||
}
|
||||
|
||||
// Prev/next UIDs for navigation.
|
||||
msgs, _ := s.deps.DB.ListIMAPMessages(ctx, boxID)
|
||||
var prevUID, nextUID uint32
|
||||
for i, m := range msgs {
|
||||
if m.UID == uid {
|
||||
if i > 0 {
|
||||
nextUID = msgs[i-1].UID // list is ascending, so "next newer"
|
||||
}
|
||||
if i < len(msgs)-1 {
|
||||
prevUID = msgs[i+1].UID // older
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
flash, errMsg := flashFrom(r)
|
||||
base := s.newBase(r, flash, errMsg)
|
||||
base.CurrentBoxID = boxID
|
||||
s.render(w, "message", messagePage{
|
||||
basePage: base,
|
||||
CurrentBox: box,
|
||||
Message: msg,
|
||||
Body: body,
|
||||
PrevUID: prevUID,
|
||||
NextUID: nextUID,
|
||||
})
|
||||
}
|
||||
|
||||
// ---- Message actions ----
|
||||
|
||||
func (s *Server) messageFlag(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
user := s.currentUser(r)
|
||||
boxID := pathID(r, "boxid")
|
||||
uid64, _ := strconv.ParseUint(r.PathValue("uid"), 10, 32)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
|
||||
if err != nil || box == nil || box.UserID != user.ID {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := s.deps.DB.GetIMAPMessageByUID(ctx, boxID, uint32(uid64))
|
||||
if err != nil || msg == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
flag := r.FormValue("flag")
|
||||
isRead := msg.IsRead
|
||||
isStar := msg.IsStarred
|
||||
switch flag {
|
||||
case "read":
|
||||
isRead = !isRead
|
||||
case "star":
|
||||
isStar = !isStar
|
||||
}
|
||||
|
||||
if err := s.deps.DB.SetMessageFlags(ctx, msg.ID, isRead, isStar, msg.IsDraft, msg.Flags); err != nil {
|
||||
log.Printf("[webmail] flag: %v", err)
|
||||
}
|
||||
|
||||
// Return to message or mailbox depending on referrer.
|
||||
returnTo := r.FormValue("return")
|
||||
if returnTo == "" {
|
||||
returnTo = fmt.Sprintf("/mail/%d/%d", boxID, uid64)
|
||||
}
|
||||
http.Redirect(w, r, returnTo, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (s *Server) messageTrash(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
user := s.currentUser(r)
|
||||
boxID := pathID(r, "boxid")
|
||||
uid64, _ := strconv.ParseUint(r.PathValue("uid"), 10, 32)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
|
||||
if err != nil || box == nil || box.UserID != user.ID {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := s.deps.DB.GetIMAPMessageByUID(ctx, boxID, uint32(uid64))
|
||||
if err != nil || msg == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if box.Type == models.MailboxTrash {
|
||||
// Already in trash: hard delete.
|
||||
if err := s.deps.DB.SoftDeleteMessage(ctx, msg.ID); err != nil {
|
||||
log.Printf("[webmail] soft delete: %v", err)
|
||||
}
|
||||
if _, err := s.deps.DB.HardDeleteMessages(ctx, boxID); err != nil {
|
||||
log.Printf("[webmail] hard delete: %v", err)
|
||||
}
|
||||
redirect(w, r, fmt.Sprintf("/mail/%d", boxID), "Message permanently deleted.", "")
|
||||
return
|
||||
}
|
||||
|
||||
// Move to trash: copy then soft-delete original.
|
||||
trashBox, err := s.deps.DB.GetMailboxByType(ctx, user.ID, models.MailboxTrash)
|
||||
if err != nil || trashBox == nil {
|
||||
redirect(w, r, fmt.Sprintf("/mail/%d", boxID), "", "Trash folder not found.")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := s.deps.DB.CopyMessageToMailbox(ctx, msg.ID, trashBox.ID, user.ID); err != nil {
|
||||
log.Printf("[webmail] copy to trash: %v", err)
|
||||
redirect(w, r, fmt.Sprintf("/mail/%d", boxID), "", "Move to trash failed.")
|
||||
return
|
||||
}
|
||||
if err := s.deps.DB.SoftDeleteMessage(ctx, msg.ID); err != nil {
|
||||
log.Printf("[webmail] soft delete orig: %v", err)
|
||||
}
|
||||
|
||||
redirect(w, r, fmt.Sprintf("/mail/%d", boxID), "Moved to trash.", "")
|
||||
}
|
||||
|
||||
func (s *Server) messageMove(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
user := s.currentUser(r)
|
||||
boxID := pathID(r, "boxid")
|
||||
uid64, _ := strconv.ParseUint(r.PathValue("uid"), 10, 32)
|
||||
destBoxID, _ := strconv.ParseInt(r.FormValue("dest_box"), 10, 64)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
|
||||
if err != nil || box == nil || box.UserID != user.ID {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
destBox, err := s.deps.DB.GetMailboxByID(ctx, destBoxID)
|
||||
if err != nil || destBox == nil || destBox.UserID != user.ID {
|
||||
redirect(w, r, fmt.Sprintf("/mail/%d", boxID), "", "Destination folder not found.")
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := s.deps.DB.GetIMAPMessageByUID(ctx, boxID, uint32(uid64))
|
||||
if err != nil || msg == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := s.deps.DB.CopyMessageToMailbox(ctx, msg.ID, destBoxID, user.ID); err != nil {
|
||||
redirect(w, r, fmt.Sprintf("/mail/%d", boxID), "", "Move failed.")
|
||||
return
|
||||
}
|
||||
if err := s.deps.DB.SoftDeleteMessage(ctx, msg.ID); err != nil {
|
||||
log.Printf("[webmail] soft delete on move: %v", err)
|
||||
}
|
||||
|
||||
redirect(w, r, fmt.Sprintf("/mail/%d", boxID), "Moved.", "")
|
||||
}
|
||||
|
||||
func (s *Server) mailboxExpunge(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
user := s.currentUser(r)
|
||||
boxID := pathID(r, "boxid")
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
|
||||
if err != nil || box == nil || box.UserID != user.ID {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark all as deleted first (for boxes that only soft-delete).
|
||||
if _, err := s.deps.DB.HardDeleteMessages(ctx, boxID); err != nil {
|
||||
log.Printf("[webmail] expunge: %v", err)
|
||||
redirect(w, r, fmt.Sprintf("/mail/%d", boxID), "", "Expunge failed.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, fmt.Sprintf("/mail/%d", boxID), "Expunged.", "")
|
||||
}
|
||||
|
||||
// ---- Compose ----
|
||||
|
||||
type composePage struct {
|
||||
basePage
|
||||
To string
|
||||
CC string
|
||||
Subject string
|
||||
BodyText string
|
||||
InReplyTo string
|
||||
References string
|
||||
}
|
||||
|
||||
func (s *Server) composeGet(w http.ResponseWriter, r *http.Request) {
|
||||
user := s.currentUser(r)
|
||||
flash, errMsg := flashFrom(r)
|
||||
base := s.newBase(r, flash, errMsg)
|
||||
|
||||
p := composePage{basePage: base}
|
||||
|
||||
// Handle reply/forward.
|
||||
action := r.URL.Query().Get("action")
|
||||
boxID, _ := strconv.ParseInt(r.URL.Query().Get("boxid"), 10, 64)
|
||||
uid64, _ := strconv.ParseUint(r.URL.Query().Get("uid"), 10, 32)
|
||||
|
||||
if (action == "reply" || action == "forward") && boxID > 0 && uid64 > 0 {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
box, err := s.deps.DB.GetMailboxByID(ctx, boxID)
|
||||
if err == nil && box != nil && box.UserID == user.ID {
|
||||
orig, err := s.deps.DB.GetIMAPMessageByUID(ctx, boxID, uint32(uid64))
|
||||
if err == nil && orig != nil {
|
||||
body, err := s.deps.Store.GetBodyParts(ctx, user.ID, orig.ID)
|
||||
if err != nil {
|
||||
body = &storage.BodyParts{}
|
||||
}
|
||||
|
||||
if action == "reply" {
|
||||
replyAddr := orig.FromEmail
|
||||
if orig.FromName != "" {
|
||||
replyAddr = orig.FromName + " <" + orig.FromEmail + ">"
|
||||
}
|
||||
p.To = replyAddr
|
||||
p.Subject = reSubject(orig.Subject)
|
||||
p.InReplyTo = orig.MessageID
|
||||
p.References = orig.MessageID
|
||||
p.BodyText = quoteBody(orig.FromEmail, orig.Date, body.Text)
|
||||
} else {
|
||||
p.Subject = fwdSubject(orig.Subject)
|
||||
p.BodyText = fwdBody(orig.FromEmail, orig.ToList, orig.Date, orig.Subject, body.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.render(w, "compose", p)
|
||||
}
|
||||
|
||||
func (s *Server) composeSend(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
user := s.currentUser(r)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
toRaw := strings.TrimSpace(r.FormValue("to"))
|
||||
ccRaw := strings.TrimSpace(r.FormValue("cc"))
|
||||
bccRaw := strings.TrimSpace(r.FormValue("bcc"))
|
||||
subject := strings.TrimSpace(r.FormValue("subject"))
|
||||
bodyText := r.FormValue("body")
|
||||
inReplyTo := strings.TrimSpace(r.FormValue("in_reply_to"))
|
||||
references := strings.TrimSpace(r.FormValue("references"))
|
||||
|
||||
if toRaw == "" {
|
||||
redirect(w, r, "/compose", "", "To field is required.")
|
||||
return
|
||||
}
|
||||
if len(subject) > 998 {
|
||||
subject = subject[:998]
|
||||
}
|
||||
if len(bodyText) > 10*1024*1024 {
|
||||
redirect(w, r, "/compose", "", "Message body too large (max 10 MB).")
|
||||
return
|
||||
}
|
||||
|
||||
toAddrs, err := parseAddressList(toRaw)
|
||||
if err != nil || len(toAddrs) == 0 {
|
||||
redirect(w, r, "/compose", "", "Invalid To address: "+err.Error())
|
||||
return
|
||||
}
|
||||
ccAddrs, _ := parseAddressList(ccRaw)
|
||||
bccAddrs, _ := parseAddressList(bccRaw)
|
||||
|
||||
// Build from address.
|
||||
displayName := user.DisplayName
|
||||
if displayName == "" {
|
||||
displayName = user.Username
|
||||
}
|
||||
fromAddr := mail.Address{Name: displayName, Address: user.Email}
|
||||
fromRFC := fromAddr.String()
|
||||
|
||||
allRecipients := append(append(toAddrs, ccAddrs...), bccAddrs...)
|
||||
|
||||
p := &ComposeParams{
|
||||
From: fromRFC,
|
||||
FromEmail: user.Email,
|
||||
To: addressListRFC(toAddrs),
|
||||
CC: addressListRFC(ccAddrs),
|
||||
BCC: addressListRFC(bccAddrs),
|
||||
Subject: subject,
|
||||
BodyText: bodyText,
|
||||
InReplyTo: inReplyTo,
|
||||
References: references,
|
||||
}
|
||||
|
||||
raw, err := BuildRFC5322(p)
|
||||
if err != nil {
|
||||
log.Printf("[webmail] build message: %v", err)
|
||||
redirect(w, r, "/compose", "", "Failed to build message.")
|
||||
return
|
||||
}
|
||||
|
||||
incomingMsg := &storage.IncomingMessage{
|
||||
Raw: raw,
|
||||
FromEmail: user.Email,
|
||||
FromName: displayName,
|
||||
ToList: strings.Join(toAddrs, ", "),
|
||||
CCList: strings.Join(ccAddrs, ", "),
|
||||
BCCList: strings.Join(bccAddrs, ", "),
|
||||
Subject: subject,
|
||||
Date: time.Now().UTC(),
|
||||
MessageID: p.MessageID,
|
||||
}
|
||||
|
||||
// Save to Sent.
|
||||
if err := s.saveSentCopy(ctx, user.ID, raw, incomingMsg); err != nil {
|
||||
log.Printf("[webmail] save sent: %v", err)
|
||||
}
|
||||
|
||||
// Deliver to each recipient.
|
||||
var deliveryErrors []string
|
||||
for _, rcpt := range allRecipients {
|
||||
rcptDomain := ""
|
||||
if idx := strings.LastIndex(rcpt, "@"); idx >= 0 {
|
||||
rcptDomain = rcpt[idx+1:]
|
||||
}
|
||||
|
||||
// Check if local domain.
|
||||
isLocal := false
|
||||
if rcptDomain != "" {
|
||||
isLocal, _ = s.deps.DB.IsLocalDomain(ctx, rcptDomain)
|
||||
}
|
||||
|
||||
if isLocal {
|
||||
if err := s.deliverLocally(ctx, rcpt, raw, incomingMsg); err != nil {
|
||||
log.Printf("[webmail] local deliver %s: %v", rcpt, err)
|
||||
deliveryErrors = append(deliveryErrors, rcpt+": "+err.Error())
|
||||
}
|
||||
} else {
|
||||
if err := s.enqueueForDelivery(ctx, user.Email, rcpt, raw, p.MessageID); err != nil {
|
||||
log.Printf("[webmail] enqueue %s: %v", rcpt, err)
|
||||
deliveryErrors = append(deliveryErrors, rcpt+": queued (may fail)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(deliveryErrors) > 0 {
|
||||
redirect(w, r, "/compose", "", "Sent with errors: "+strings.Join(deliveryErrors, "; "))
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to Sent folder.
|
||||
sentBox, _ := s.deps.DB.GetMailboxByType(ctx, user.ID, models.MailboxSent)
|
||||
if sentBox != nil {
|
||||
redirect(w, r, fmt.Sprintf("/mail/%d", sentBox.ID), "Message sent.", "")
|
||||
} else {
|
||||
redirect(w, r, "/", "Message sent.", "")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Settings ----
|
||||
|
||||
type settingsPage struct {
|
||||
basePage
|
||||
AccountUser *models.User
|
||||
}
|
||||
|
||||
func (s *Server) settingsGet(w http.ResponseWriter, r *http.Request) {
|
||||
user := s.currentUser(r)
|
||||
flash, errMsg := flashFrom(r)
|
||||
s.render(w, "settings", settingsPage{
|
||||
basePage: s.newBase(r, flash, errMsg),
|
||||
AccountUser: user,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) settingsPassword(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
user := s.currentUser(r)
|
||||
|
||||
current := r.FormValue("current_password")
|
||||
newPw := r.FormValue("new_password")
|
||||
confirm := r.FormValue("confirm_password")
|
||||
|
||||
if err := appCrypto.CheckPassword(user.PasswordHash, current); err != nil {
|
||||
redirect(w, r, "/settings", "", "Current password is incorrect.")
|
||||
return
|
||||
}
|
||||
if len(newPw) < 8 || len(newPw) > 1024 {
|
||||
redirect(w, r, "/settings", "", "New password must be 8-1024 characters.")
|
||||
return
|
||||
}
|
||||
if newPw != confirm {
|
||||
redirect(w, r, "/settings", "", "Passwords do not match.")
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := appCrypto.HashPassword(newPw)
|
||||
if err != nil {
|
||||
redirect(w, r, "/settings", "", "Password error.")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.deps.DB.SetUserPassword(ctx, user.ID, hash); err != nil {
|
||||
log.Printf("[webmail] set password: %v", err)
|
||||
redirect(w, r, "/settings", "", "Failed to update password.")
|
||||
return
|
||||
}
|
||||
|
||||
// Log out all other sessions for security.
|
||||
if err := s.deps.Sessions.DestroyAll(ctx, user.ID); err != nil {
|
||||
log.Printf("[webmail] destroy sessions: %v", err)
|
||||
}
|
||||
// Create fresh session for current request.
|
||||
if _, err := s.deps.Sessions.Create(w, r, user.ID); err != nil {
|
||||
log.Printf("[webmail] re-create session: %v", err)
|
||||
}
|
||||
|
||||
redirect(w, r, "/settings", "Password updated. All other sessions logged out.", "")
|
||||
}
|
||||
|
||||
func (s *Server) settingsDisplay(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
user := s.currentUser(r)
|
||||
displayName := strings.TrimSpace(r.FormValue("display_name"))
|
||||
if len(displayName) > 255 {
|
||||
displayName = displayName[:255]
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.deps.DB.SetUserDisplayName(ctx, user.ID, displayName); err != nil {
|
||||
log.Printf("[webmail] display name: %v", err)
|
||||
redirect(w, r, "/settings", "", "Update failed.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, "/settings", "Display name updated.", "")
|
||||
}
|
||||
|
||||
// ---- helpers ----
|
||||
|
||||
func pathID(r *http.Request, key string) int64 {
|
||||
id, _ := strconv.ParseInt(r.PathValue(key), 10, 64)
|
||||
return id
|
||||
}
|
||||
|
||||
func reSubject(s string) string {
|
||||
low := strings.ToLower(s)
|
||||
if strings.HasPrefix(low, "re:") {
|
||||
return s
|
||||
}
|
||||
return "Re: " + s
|
||||
}
|
||||
|
||||
func fwdSubject(s string) string {
|
||||
low := strings.ToLower(s)
|
||||
if strings.HasPrefix(low, "fwd:") || strings.HasPrefix(low, "fw:") {
|
||||
return s
|
||||
}
|
||||
return "Fwd: " + s
|
||||
}
|
||||
|
||||
func quoteBody(from string, date time.Time, text string) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("\r\n\r\n")
|
||||
sb.WriteString("On " + date.Format("Mon, 2 Jan 2006 at 15:04") + ", " + from + " wrote:\r\n")
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
sb.WriteString("> " + strings.TrimRight(line, "\r") + "\r\n")
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func fwdBody(from, to string, date time.Time, subject, text string) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("\r\n\r\n-------- Forwarded Message --------\r\n")
|
||||
sb.WriteString("From: " + from + "\r\n")
|
||||
sb.WriteString("To: " + to + "\r\n")
|
||||
sb.WriteString("Date: " + date.Format(time.RFC1123Z) + "\r\n")
|
||||
sb.WriteString("Subject: " + subject + "\r\n\r\n")
|
||||
sb.WriteString(text)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// ---- TOTP / MFA enrollment ----
|
||||
|
||||
type mfaEnrollPage struct {
|
||||
basePage
|
||||
Secret string // base32, shown for manual entry
|
||||
OTPAuthURI string // otpauth:// URI for QR code
|
||||
}
|
||||
|
||||
// mfaEnrollGet generates a new TOTP secret, stores it unconfirmed in a signed
|
||||
// session cookie, and renders the enrollment form.
|
||||
func (s *Server) mfaEnrollGet(w http.ResponseWriter, r *http.Request) {
|
||||
user := s.currentUser(r)
|
||||
flash, errMsg := flashFrom(r)
|
||||
|
||||
secret, err := totp.GenerateSecret()
|
||||
if err != nil {
|
||||
log.Printf("[webmail] mfa generate: %v", err)
|
||||
redirect(w, r, "/settings", "", "Failed to generate MFA secret.")
|
||||
return
|
||||
}
|
||||
|
||||
issuer := s.deps.Cfg.DefaultDomain
|
||||
if issuer == "" {
|
||||
issuer = s.deps.Cfg.Hostname
|
||||
}
|
||||
if issuer == "" {
|
||||
issuer = "mailgosend"
|
||||
}
|
||||
uri := totp.OTPAuthURI(secret, user.Email, issuer)
|
||||
|
||||
// Stash pending secret in a short-lived signed cookie so that the POST
|
||||
// can verify the code before persisting to the DB.
|
||||
s.setPendingTOTPCookie(w, user.ID, secret)
|
||||
|
||||
s.render(w, "mfa_enroll", mfaEnrollPage{
|
||||
basePage: s.newBase(r, flash, errMsg),
|
||||
Secret: secret,
|
||||
OTPAuthURI: uri,
|
||||
})
|
||||
}
|
||||
|
||||
// mfaEnrollPost verifies the TOTP code from the enrollment form and, on
|
||||
// success, encrypts + persists the secret and generates recovery codes.
|
||||
func (s *Server) mfaEnrollPost(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
user := s.currentUser(r)
|
||||
|
||||
code := strings.TrimSpace(r.FormValue("code"))
|
||||
if len(code) != totp.Digits {
|
||||
redirect(w, r, "/settings/mfa/enroll", "", "Enter the 6-digit code from your authenticator app.")
|
||||
return
|
||||
}
|
||||
|
||||
// Read and validate pending secret cookie.
|
||||
secret, ok := s.pendingTOTPSecret(r, user.ID)
|
||||
if !ok {
|
||||
redirect(w, r, "/settings/mfa/enroll", "", "Enrollment session expired. Please start over.")
|
||||
return
|
||||
}
|
||||
|
||||
if !totp.Verify(secret, code) {
|
||||
redirect(w, r, "/settings/mfa/enroll", "", "Code did not match. Check your authenticator and try again.")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Encrypt and store secret.
|
||||
encSecret, err := s.deps.Crypt.EncryptForUser(user.ID, "totp", []byte(secret))
|
||||
if err != nil {
|
||||
log.Printf("[webmail] mfa encrypt secret: %v", err)
|
||||
redirect(w, r, "/settings", "", "MFA setup failed.")
|
||||
return
|
||||
}
|
||||
if err := s.deps.DB.SetMFASecret(ctx, user.ID, encSecret); err != nil {
|
||||
log.Printf("[webmail] mfa save secret: %v", err)
|
||||
redirect(w, r, "/settings", "", "MFA setup failed.")
|
||||
return
|
||||
}
|
||||
|
||||
// Generate and store recovery codes.
|
||||
codes, err := totp.GenerateRecoveryCodes()
|
||||
if err != nil {
|
||||
log.Printf("[webmail] mfa recovery codes: %v", err)
|
||||
} else {
|
||||
codesJSON, _ := totp.EncodeRecoveryCodes(codes)
|
||||
encCodes, encErr := s.deps.Crypt.EncryptForUser(user.ID, "recovery", codesJSON)
|
||||
if encErr == nil {
|
||||
_ = s.deps.DB.SetRecoveryCodes(ctx, user.ID, encCodes)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.deps.DB.SetMFAEnabled(ctx, user.ID, true); err != nil {
|
||||
log.Printf("[webmail] mfa enable: %v", err)
|
||||
redirect(w, r, "/settings", "", "MFA setup failed.")
|
||||
return
|
||||
}
|
||||
|
||||
clearPendingTOTP(w)
|
||||
redirect(w, r, "/settings", "Two-factor authentication enabled.", "")
|
||||
}
|
||||
|
||||
// mfaDisable disables MFA after verifying current password.
|
||||
func (s *Server) mfaDisable(w http.ResponseWriter, r *http.Request) {
|
||||
if !s.validateCSRF(w, r) {
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
user := s.currentUser(r)
|
||||
|
||||
pw := r.FormValue("password")
|
||||
if err := appCrypto.CheckPassword(user.PasswordHash, pw); err != nil {
|
||||
redirect(w, r, "/settings", "", "Incorrect password. MFA not disabled.")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.deps.DB.ClearMFA(ctx, user.ID); err != nil {
|
||||
log.Printf("[webmail] mfa disable: %v", err)
|
||||
redirect(w, r, "/settings", "", "Failed to disable MFA.")
|
||||
return
|
||||
}
|
||||
redirect(w, r, "/settings", "Two-factor authentication disabled.", "")
|
||||
}
|
||||
|
||||
// ---- Pending TOTP cookie (enrollment flow) ----
|
||||
|
||||
const pendingTOTPCookie = "mailgo_enroll"
|
||||
const pendingTOTPMaxAge = 300
|
||||
|
||||
func (s *Server) setPendingTOTPCookie(w http.ResponseWriter, userID int64, secret string) {
|
||||
uid := fmt.Sprintf("%d", userID)
|
||||
ts := fmt.Sprintf("%d", time.Now().Unix())
|
||||
payload := uid + "|" + ts + "|" + secret
|
||||
mac := preAuthMAC(s.deps.Cfg.SessionSecret, payload)
|
||||
value := payload + "|" + mac
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: pendingTOTPCookie,
|
||||
Value: value,
|
||||
Path: "/settings/mfa",
|
||||
MaxAge: pendingTOTPMaxAge,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) pendingTOTPSecret(r *http.Request, userID int64) (string, bool) {
|
||||
c, err := r.Cookie(pendingTOTPCookie)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
// Split from the right: last segment is MAC.
|
||||
idx := strings.LastIndex(c.Value, "|")
|
||||
if idx < 0 {
|
||||
return "", false
|
||||
}
|
||||
payload := c.Value[:idx]
|
||||
gotMAC := c.Value[idx+1:]
|
||||
|
||||
wantMAC := preAuthMAC(s.deps.Cfg.SessionSecret, payload)
|
||||
if !strings.EqualFold(gotMAC, wantMAC) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// payload = uid|ts|secret
|
||||
parts := strings.SplitN(payload, "|", 3)
|
||||
if len(parts) != 3 {
|
||||
return "", false
|
||||
}
|
||||
uid, ts, secret := parts[0], parts[1], parts[2]
|
||||
|
||||
id, err := strconv.ParseInt(uid, 10, 64)
|
||||
if err != nil || id != userID {
|
||||
return "", false
|
||||
}
|
||||
|
||||
tsi, err := strconv.ParseInt(ts, 10, 64)
|
||||
if err != nil || time.Now().Unix()-tsi > pendingTOTPMaxAge {
|
||||
return "", false
|
||||
}
|
||||
|
||||
if secret == "" {
|
||||
return "", false
|
||||
}
|
||||
return secret, true
|
||||
}
|
||||
|
||||
func clearPendingTOTP(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: pendingTOTPCookie,
|
||||
Value: "",
|
||||
Path: "/settings/mfa",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,281 @@
|
||||
// Package webclient provides the webmail HTTP client (default: 0.0.0.0:8080).
|
||||
package webclient
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io/fs"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/auth"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/config"
|
||||
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/storage"
|
||||
)
|
||||
|
||||
// Deps groups all dependencies for the webclient.
|
||||
type Deps struct {
|
||||
DB *db.DB
|
||||
Crypt *appCrypto.Crypto
|
||||
Sessions *auth.SessionStore
|
||||
Brute *auth.BruteGuard
|
||||
Store *storage.Store
|
||||
Cfg *config.Config
|
||||
FS fs.FS // embed.FS sub-rooted at web/client
|
||||
}
|
||||
|
||||
// Server handles all webclient HTTP routes.
|
||||
type Server struct {
|
||||
deps *Deps
|
||||
mux *http.ServeMux
|
||||
}
|
||||
|
||||
// New creates a webclient Server and registers all routes.
|
||||
func New(deps *Deps) *Server {
|
||||
s := &Server{deps: deps, mux: http.NewServeMux()}
|
||||
s.setupRoutes()
|
||||
return s
|
||||
}
|
||||
|
||||
// Handler returns the HTTP handler.
|
||||
func (s *Server) Handler() http.Handler {
|
||||
return logMiddleware(s.mux)
|
||||
}
|
||||
|
||||
// ---- Routing ----
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
m := s.mux
|
||||
|
||||
// Public
|
||||
m.HandleFunc("GET /login", s.loginGet)
|
||||
m.HandleFunc("POST /login", s.loginPost)
|
||||
m.HandleFunc("GET /login/mfa", s.mfaGet)
|
||||
m.HandleFunc("POST /login/mfa", s.mfaPost)
|
||||
m.HandleFunc("GET /logout", s.logout)
|
||||
|
||||
// Root redirect
|
||||
m.HandleFunc("GET /{$}", s.require(s.rootRedirect))
|
||||
|
||||
// Mailbox + message routes
|
||||
m.HandleFunc("GET /mail/{boxid}", s.require(s.mailboxView))
|
||||
m.HandleFunc("GET /mail/{boxid}/{uid}", s.require(s.messageView))
|
||||
m.HandleFunc("POST /mail/{boxid}/{uid}/flag", s.require(s.messageFlag))
|
||||
m.HandleFunc("POST /mail/{boxid}/{uid}/trash", s.require(s.messageTrash))
|
||||
m.HandleFunc("POST /mail/{boxid}/{uid}/move", s.require(s.messageMove))
|
||||
m.HandleFunc("POST /mail/{boxid}/expunge", s.require(s.mailboxExpunge))
|
||||
|
||||
// Compose
|
||||
m.HandleFunc("GET /compose", s.require(s.composeGet))
|
||||
m.HandleFunc("POST /compose", s.require(s.composeSend))
|
||||
|
||||
// Settings
|
||||
m.HandleFunc("GET /settings", s.require(s.settingsGet))
|
||||
m.HandleFunc("POST /settings/password", s.require(s.settingsPassword))
|
||||
m.HandleFunc("POST /settings/display", s.require(s.settingsDisplay))
|
||||
m.HandleFunc("GET /settings/mfa/enroll", s.require(s.mfaEnrollGet))
|
||||
m.HandleFunc("POST /settings/mfa/enroll", s.require(s.mfaEnrollPost))
|
||||
m.HandleFunc("POST /settings/mfa/disable", s.require(s.mfaDisable))
|
||||
|
||||
// Static assets
|
||||
static, _ := fs.Sub(s.deps.FS, "static")
|
||||
m.Handle("GET /static/", http.StripPrefix("/static/", http.FileServer(http.FS(static))))
|
||||
}
|
||||
|
||||
// ---- Middleware ----
|
||||
|
||||
func (s *Server) require(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if s.currentUser(r) == nil {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) currentUser(r *http.Request) *models.User {
|
||||
_, user, err := s.deps.Sessions.Get(r)
|
||||
if err != nil || user == nil || !user.Enabled {
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
||||
|
||||
func logMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
next.ServeHTTP(w, r)
|
||||
log.Printf("[webmail] %s %s %s", r.Method, r.URL.Path, time.Since(start))
|
||||
})
|
||||
}
|
||||
|
||||
// ---- Template rendering ----
|
||||
|
||||
type basePage struct {
|
||||
Flash string
|
||||
Error string
|
||||
CSRF string
|
||||
User *models.User
|
||||
Mailboxes []*models.Mailbox
|
||||
CurrentBoxID int64 // 0 = no mailbox selected (compose, settings, etc.)
|
||||
}
|
||||
|
||||
func (s *Server) newBase(r *http.Request, flash, errMsg string) basePage {
|
||||
user := s.currentUser(r)
|
||||
sessObj, _, _ := s.deps.Sessions.Get(r)
|
||||
var csrf string
|
||||
if sessObj != nil {
|
||||
csrf = s.csrfToken(sessObj.TokenHash)
|
||||
}
|
||||
|
||||
var boxes []*models.Mailbox
|
||||
if user != nil {
|
||||
boxes, _ = s.deps.DB.ListMailboxes(r.Context(), user.ID)
|
||||
}
|
||||
return basePage{Flash: flash, Error: errMsg, CSRF: csrf, User: user, Mailboxes: boxes}
|
||||
}
|
||||
|
||||
func (s *Server) render(w http.ResponseWriter, page string, data any) {
|
||||
t, err := template.New("").Funcs(tmplFuncs).ParseFS(
|
||||
s.deps.FS,
|
||||
"templates/base.html",
|
||||
"templates/"+page+".html",
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("[webmail] template parse %s: %v", page, err)
|
||||
http.Error(w, "template error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "SAMEORIGIN")
|
||||
if err := t.ExecuteTemplate(w, "base", data); err != nil {
|
||||
log.Printf("[webmail] template exec %s: %v", page, err)
|
||||
}
|
||||
}
|
||||
|
||||
// redirect sends 303 with optional flash/error query param.
|
||||
func redirect(w http.ResponseWriter, r *http.Request, path, flash, errMsg string) {
|
||||
target := path
|
||||
if flash != "" {
|
||||
target += "?flash=" + template.URLQueryEscaper(flash)
|
||||
} else if errMsg != "" {
|
||||
target += "?error=" + template.URLQueryEscaper(errMsg)
|
||||
}
|
||||
http.Redirect(w, r, target, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func flashFrom(r *http.Request) (flash, errMsg string) {
|
||||
return r.URL.Query().Get("flash"), r.URL.Query().Get("error")
|
||||
}
|
||||
|
||||
// ---- CSRF ----
|
||||
|
||||
// csrfToken returns an HMAC-SHA256 token valid for the current clock-hour.
|
||||
func (s *Server) csrfToken(sessionHash string) string {
|
||||
return computeCSRF(sessionHash, s.deps.Cfg.SessionSecret, time.Now().UTC())
|
||||
}
|
||||
|
||||
func computeCSRF(sessionHash string, secret []byte, t time.Time) string {
|
||||
h := hmac.New(sha256.New, secret)
|
||||
h.Write([]byte(sessionHash))
|
||||
h.Write([]byte(t.Format("2006-01-02-15")))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// checkCSRF validates the CSRF token from the form field "_csrf".
|
||||
func (s *Server) checkCSRF(r *http.Request, sessionHash string) bool {
|
||||
got := r.FormValue("_csrf")
|
||||
if got == "" {
|
||||
return false
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
cur := computeCSRF(sessionHash, s.deps.Cfg.SessionSecret, now)
|
||||
if hmac.Equal([]byte(got), []byte(cur)) {
|
||||
return true
|
||||
}
|
||||
prev := computeCSRF(sessionHash, s.deps.Cfg.SessionSecret, now.Add(-time.Hour))
|
||||
return hmac.Equal([]byte(got), []byte(prev))
|
||||
}
|
||||
|
||||
// validateCSRF checks CSRF and writes 403 on failure. Returns false on failure.
|
||||
func (s *Server) validateCSRF(w http.ResponseWriter, r *http.Request) bool {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return false
|
||||
}
|
||||
sess, _, _ := s.deps.Sessions.Get(r)
|
||||
if sess == nil {
|
||||
http.Error(w, "unauthenticated", http.StatusForbidden)
|
||||
return false
|
||||
}
|
||||
if !s.checkCSRF(r, sess.TokenHash) {
|
||||
http.Error(w, "CSRF validation failed", http.StatusForbidden)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ---- Template funcs ----
|
||||
|
||||
var tmplFuncs = template.FuncMap{
|
||||
"humanBytes": humanBytes,
|
||||
"shortTime": func(t time.Time) string { return t.Format("2006-01-02 15:04") },
|
||||
"shortDate": func(t time.Time) string {
|
||||
now := time.Now()
|
||||
if t.Year() == now.Year() && t.Month() == now.Month() && t.Day() == now.Day() {
|
||||
return t.Format("15:04")
|
||||
}
|
||||
return t.Format("Jan 2")
|
||||
},
|
||||
"isZero": func(t time.Time) bool { return t.IsZero() },
|
||||
"add": func(a, b int) int { return a + b },
|
||||
"truncate": func(s string, n int) string {
|
||||
r := []rune(s)
|
||||
if len(r) <= n {
|
||||
return s
|
||||
}
|
||||
return string(r[:n]) + "..."
|
||||
},
|
||||
"mailboxLabel": mailboxLabel,
|
||||
"safeHTML": func(s string) template.HTML { return template.HTML(s) }, //nolint:gosec
|
||||
}
|
||||
|
||||
func mailboxLabel(mboxType string) string {
|
||||
switch mboxType {
|
||||
case "inbox":
|
||||
return "Inbox"
|
||||
case "sent":
|
||||
return "Sent"
|
||||
case "drafts":
|
||||
return "Drafts"
|
||||
case "trash":
|
||||
return "Trash"
|
||||
case "spam":
|
||||
return "Spam"
|
||||
case "archive":
|
||||
return "Archive"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func humanBytes(b int64) string {
|
||||
const unit = 1024
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
Reference in New Issue
Block a user