This commit is contained in:
2026-05-21 20:27:58 +00:00
parent 0d2615a9fd
commit 5a127bf2a2
38 changed files with 8644 additions and 0 deletions
+201
View File
@@ -0,0 +1,201 @@
package auth
import (
"context"
"database/sql"
"fmt"
"time"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
)
// BruteGuard tracks login attempts and bans IPs exceeding the threshold.
type BruteGuard struct {
db *db.DB
maxTries int
windowMin int
banHours int
whitelist map[string]struct{} // exempt IPs
}
// NewBruteGuard creates a brute-force guard.
func NewBruteGuard(database *db.DB, maxTries, windowMin, banHours int, whitelist []string) *BruteGuard {
wl := make(map[string]struct{}, len(whitelist))
for _, ip := range whitelist {
wl[ip] = struct{}{}
}
return &BruteGuard{
db: database,
maxTries: maxTries,
windowMin: windowMin,
banHours: banHours,
whitelist: wl,
}
}
// IsBanned returns true if the IP is currently banned.
func (g *BruteGuard) IsBanned(ctx context.Context, ip string) (bool, error) {
if _, ok := g.whitelist[ip]; ok {
return false, nil
}
var expiresAt sql.NullTime
err := g.db.SQL().QueryRowContext(ctx,
"SELECT expires_at FROM ip_bans WHERE ip = ?", ip).Scan(&expiresAt)
if err == sql.ErrNoRows {
return false, nil
}
if err != nil {
return false, fmt.Errorf("check ban: %w", err)
}
// Permanent ban (expires_at IS NULL) or not yet expired.
if !expiresAt.Valid || expiresAt.Time.After(time.Now().UTC()) {
return true, nil
}
// Expired ban — clean up.
_, _ = g.db.SQL().ExecContext(ctx, "DELETE FROM ip_bans WHERE ip = ?", ip)
return false, nil
}
// RecordAttempt records a login attempt and bans the IP if threshold exceeded.
// Returns (banned, error).
func (g *BruteGuard) RecordAttempt(ctx context.Context, ip, email string, success bool) (bool, error) {
if _, ok := g.whitelist[ip]; ok {
return false, nil
}
// Insert attempt.
_, err := g.db.SQL().ExecContext(ctx,
"INSERT INTO login_attempts (ip, user_email, success, created_at) VALUES (?, ?, ?, ?)",
ip, email, success, time.Now().UTC())
if err != nil {
return false, fmt.Errorf("record attempt: %w", err)
}
if success {
// Successful login resets the counter (remove recent failed attempts).
window := time.Now().UTC().Add(-time.Duration(g.windowMin) * time.Minute)
_, _ = g.db.SQL().ExecContext(ctx,
"DELETE FROM login_attempts WHERE ip = ? AND success = 0 AND created_at >= ?",
ip, window)
return false, nil
}
// Count failures in window.
window := time.Now().UTC().Add(-time.Duration(g.windowMin) * time.Minute)
var count int
err = g.db.SQL().QueryRowContext(ctx,
"SELECT COUNT(*) FROM login_attempts WHERE ip = ? AND success = 0 AND created_at >= ?",
ip, window).Scan(&count)
if err != nil {
return false, fmt.Errorf("count attempts: %w", err)
}
if count < g.maxTries {
return false, nil
}
// Threshold exceeded — ban the IP.
var expiresAt *time.Time
if g.banHours > 0 {
t := time.Now().UTC().Add(time.Duration(g.banHours) * time.Hour)
expiresAt = &t
}
_, err = g.db.SQL().ExecContext(ctx, `
INSERT INTO ip_bans (ip, reason, banned_at, expires_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(ip) DO UPDATE SET
reason = excluded.reason,
banned_at = excluded.banned_at,
expires_at = excluded.expires_at`,
ip,
fmt.Sprintf("brute force: %d failed attempts in %d minutes", count, g.windowMin),
time.Now().UTC(),
expiresAt,
)
if err != nil {
return false, fmt.Errorf("ban ip: %w", err)
}
// Log security event.
_, _ = g.db.SQL().ExecContext(ctx, `
INSERT INTO security_events (type, ip, detail, created_at)
VALUES ('brute_ban', ?, ?, ?)`,
ip,
fmt.Sprintf("%d failed attempts in %d min window, last target: %s", count, g.windowMin, email),
time.Now().UTC(),
)
return true, nil
}
// BanIP manually bans an IP (from admin panel).
func (g *BruteGuard) BanIP(ctx context.Context, ip, reason, bannedBy string, hours int) error {
var expiresAt *time.Time
if hours > 0 {
t := time.Now().UTC().Add(time.Duration(hours) * time.Hour)
expiresAt = &t
}
_, err := g.db.SQL().ExecContext(ctx, `
INSERT INTO ip_bans (ip, reason, banned_at, expires_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(ip) DO UPDATE SET
reason = excluded.reason,
banned_at = excluded.banned_at,
expires_at = excluded.expires_at`,
ip, reason, time.Now().UTC(), expiresAt)
return err
}
// UnbanIP removes a ban.
func (g *BruteGuard) UnbanIP(ctx context.Context, ip, releasedBy string) error {
_, err := g.db.SQL().ExecContext(ctx,
"DELETE FROM ip_bans WHERE ip = ?", ip)
return err
}
// ListBans returns active bans ordered by newest first.
func (g *BruteGuard) ListBans(ctx context.Context, limit int) ([]banRow, error) {
rows, err := g.db.SQL().QueryContext(ctx, `
SELECT ip, reason, banned_at, expires_at
FROM ip_bans
ORDER BY banned_at DESC
LIMIT ?`, limit)
if err != nil {
return nil, err
}
defer rows.Close()
var bans []banRow
for rows.Next() {
var b banRow
var exp sql.NullTime
if err := rows.Scan(&b.IP, &b.Reason, &b.BannedAt, &exp); err != nil {
return nil, err
}
if exp.Valid {
b.ExpiresAt = &exp.Time
}
bans = append(bans, b)
}
return bans, rows.Err()
}
// PurgeOldAttempts removes login attempt records older than windowMin*2.
// Call periodically to keep the table small.
func (g *BruteGuard) PurgeOldAttempts(ctx context.Context) error {
cutoff := time.Now().UTC().Add(-2 * time.Duration(g.windowMin) * time.Minute)
_, err := g.db.SQL().ExecContext(ctx,
"DELETE FROM login_attempts WHERE created_at < ?", cutoff)
return err
}
type banRow struct {
IP string
Reason string
BannedAt time.Time
ExpiresAt *time.Time
}
+276
View File
@@ -0,0 +1,276 @@
// Package auth provides session management and brute-force protection.
package auth
import (
"context"
"database/sql"
"fmt"
"net/http"
"time"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
)
const (
sessionCookieName = "mgs_session"
sessionPurpose = "session"
)
// SessionStore manages user sessions stored in the database.
type SessionStore struct {
db *db.DB
maxAge int // seconds
secureCookie bool
}
// NewSessionStore creates a session store.
func NewSessionStore(database *db.DB, maxAge int, secureCookie bool) *SessionStore {
return &SessionStore{
db: database,
maxAge: maxAge,
secureCookie: secureCookie,
}
}
// Create creates a new session for userID, sets the cookie, and returns the session.
func (s *SessionStore) Create(w http.ResponseWriter, r *http.Request, userID int64) (*models.Session, error) {
raw, hash, err := crypto.NewToken()
if err != nil {
return nil, fmt.Errorf("session token: %w", err)
}
now := time.Now().UTC()
expires := now.Add(time.Duration(s.maxAge) * time.Second)
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel()
_, err = s.db.SQL().ExecContext(ctx, `
INSERT INTO sessions (user_id, token_hash, ip, user_agent, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?)`,
userID,
hash,
realIP(r),
truncate(r.UserAgent(), 512),
now,
expires,
)
if err != nil {
return nil, fmt.Errorf("insert session: %w", err)
}
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: raw,
Path: "/",
MaxAge: s.maxAge,
HttpOnly: true,
Secure: s.secureCookie,
SameSite: http.SameSiteStrictMode,
})
return &models.Session{
UserID: userID,
TokenHash: hash,
IP: realIP(r),
CreatedAt: now,
ExpiresAt: expires,
}, nil
}
// Get validates the session cookie and returns the session + user.
// Returns (nil, nil, nil) if no valid session exists.
func (s *SessionStore) Get(r *http.Request) (*models.Session, *models.User, error) {
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
return nil, nil, nil // no cookie = not logged in
}
raw := cookie.Value
if len(raw) == 0 {
return nil, nil, nil
}
hash := crypto.HashToken(raw)
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel()
row := s.db.SQL().QueryRowContext(ctx, `
SELECT s.id, s.user_id, s.token_hash, s.ip, s.user_agent, s.created_at, s.expires_at,
u.id, u.domain_id, u.username, u.email, u.display_name,
u.quota_bytes, u.used_bytes, u.enabled, u.admin, u.domain_admin,
u.mfa_enabled, u.created_at, u.last_login
FROM sessions s
JOIN users u ON u.id = s.user_id
WHERE s.token_hash = ?
AND s.expires_at > ?
AND u.enabled = 1`,
hash, time.Now().UTC())
var sess models.Session
var user models.User
var lastLogin sql.NullTime
var createdAt sql.NullTime
err = row.Scan(
&sess.ID, &sess.UserID, &sess.TokenHash, &sess.IP, &sess.UserAgent,
&sess.CreatedAt, &sess.ExpiresAt,
&user.ID, &user.DomainID, &user.Username, &user.Email, &user.DisplayName,
&user.QuotaBytes, &user.UsedBytes, &user.Enabled, &user.Admin, &user.DomainAdmin,
&user.MFAEnabled, &createdAt, &lastLogin,
)
if err == sql.ErrNoRows {
return nil, nil, nil
}
if err != nil {
return nil, nil, fmt.Errorf("session lookup: %w", err)
}
if createdAt.Valid {
user.CreatedAt = createdAt.Time
}
if lastLogin.Valid {
user.LastLogin = lastLogin.Time
}
return &sess, &user, nil
}
// Destroy deletes the session from the database and clears the cookie.
func (s *SessionStore) Destroy(w http.ResponseWriter, r *http.Request) error {
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
return nil // no session to destroy
}
hash := crypto.HashToken(cookie.Value)
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel()
_, err = s.db.SQL().ExecContext(ctx,
"DELETE FROM sessions WHERE token_hash = ?", hash)
if err != nil {
return fmt.Errorf("delete session: %w", err)
}
http.SetCookie(w, &http.Cookie{
Name: sessionCookieName,
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
Secure: s.secureCookie,
SameSite: http.SameSiteStrictMode,
})
return nil
}
// DestroyAll deletes all sessions for a user (logout everywhere).
func (s *SessionStore) DestroyAll(ctx context.Context, userID int64) error {
_, err := s.db.SQL().ExecContext(ctx,
"DELETE FROM sessions WHERE user_id = ?", userID)
return err
}
// PurgeExpired deletes all expired sessions. Call periodically (e.g. hourly).
func (s *SessionStore) PurgeExpired(ctx context.Context) error {
_, err := s.db.SQL().ExecContext(ctx,
"DELETE FROM sessions WHERE expires_at <= ?", time.Now().UTC())
return err
}
// UpdateLastLogin updates the user's last_login timestamp.
func UpdateLastLogin(ctx context.Context, database *db.DB, userID int64) {
database.SQL().ExecContext(ctx, // nolint: errcheck — best effort
"UPDATE users SET last_login = ? WHERE id = ?",
time.Now().UTC(), userID)
}
// ---- User lookup ----
// GetUserByEmail returns the user matching email (case-insensitive), or nil.
func GetUserByEmail(ctx context.Context, database *db.DB, email string) (*models.User, error) {
row := database.SQL().QueryRowContext(ctx, `
SELECT id, domain_id, username, email, password_hash, display_name,
quota_bytes, used_bytes, enabled, admin, domain_admin,
mfa_secret_enc, mfa_enabled, recovery_codes_enc,
created_at, last_login
FROM users
WHERE lower(email) = lower(?)`, email)
return scanUser(row)
}
// GetUserByID returns the user by ID, or nil.
func GetUserByID(ctx context.Context, database *db.DB, id int64) (*models.User, error) {
row := database.SQL().QueryRowContext(ctx, `
SELECT id, domain_id, username, email, password_hash, display_name,
quota_bytes, used_bytes, enabled, admin, domain_admin,
mfa_secret_enc, mfa_enabled, recovery_codes_enc,
created_at, last_login
FROM users
WHERE id = ?`, id)
return scanUser(row)
}
func scanUser(row *sql.Row) (*models.User, error) {
var u models.User
var mfaSecretEnc, recoveryCodesEnc []byte
var lastLogin sql.NullTime
err := row.Scan(
&u.ID, &u.DomainID, &u.Username, &u.Email, &u.PasswordHash,
&u.DisplayName, &u.QuotaBytes, &u.UsedBytes, &u.Enabled,
&u.Admin, &u.DomainAdmin,
&mfaSecretEnc, &u.MFAEnabled, &recoveryCodesEnc,
&u.CreatedAt, &lastLogin,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("scan user: %w", err)
}
u.MFASecretEnc = mfaSecretEnc
u.RecoveryCodesEnc = recoveryCodesEnc
if lastLogin.Valid {
u.LastLogin = lastLogin.Time
}
return &u, nil
}
// ---- Helpers ----
func realIP(r *http.Request) string {
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return ip
}
host, _, err := parseHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}
func parseHostPort(addr string) (string, string, error) {
// net.SplitHostPort returns error for bare IPs without port.
for i := len(addr) - 1; i >= 0; i-- {
if addr[i] == ':' {
return addr[:i], addr[i+1:], nil
}
}
return addr, "", fmt.Errorf("no port")
}
func truncate(s string, max int) string {
if len(s) <= max {
return s
}
return s[:max]
}
+773
View File
@@ -0,0 +1,773 @@
// Package config loads and auto-generates app_config.conf.
// INI-style: KEY = value, # comments, blank lines ignored.
// Missing keys appended on each startup — existing values preserved.
// Env var MAILGO_<KEY> overrides file value.
package config
import (
"bufio"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"os"
"strconv"
"strings"
)
const ConfigPath = "./app_config.conf"
// Config holds all runtime configuration.
type Config struct {
// Identity
Hostname string // FQDN for SMTP HELO, TLS SNI, URL building
DefaultDomain string // Primary mail domain
// Network — SMTP
SMTPIface string
SMTPPort int
SubmitIface string
SubmitPort int // 587 STARTTLS
SMTPSPort int // 465 implicit TLS
SMTPEnabled bool
SubmitEnabled bool
SMTPSEnabled bool
// Network — IMAP
IMAPIface string
IMAPPort int // 143 STARTTLS
IMAPSPort int // 993 implicit TLS
IMAPEnabled bool
IMAPSEnabled bool
// Network — Web
WebClientIface string
WebClientPort int
WebAdminIface string
WebAdminPort int
CalDAVIface string
CalDAVPort int
// TLS
TLSMode string // dns01 | http01 | manual | off
TLSCert string // manual: path to cert.pem
TLSKey string // manual: path to key.pem
// Per-service overrides (empty = use global TLS)
SMTPTLSCert string
SMTPTLSKey string
IMAPTLSCert string
IMAPTLSKey string
WebTLSCert string
WebTLSKey string
// ACME (dns01 / http01)
ACMEEmail string
ACMECacheDir string
ACMEStaging bool
ACMEDomains []string // domains to certify; empty = just Hostname
// DNS-01 provider
ACMEDNSProvider string // cloudflare | route53 | digitalocean | hetzner | ...
// Cloudflare
CFDNSAPIToken string
CFAPIKey string
CFAPIEmail string
// Route53
AWSRegion string
AWSAccessKeyID string
AWSSecretAccessKey string
AWSHostedZoneID string
// DigitalOcean
DOAuthToken string
// Hetzner
HetznerAPIKey string
// Generic: any additional lego env vars should be exported before starting.
// Secrets (auto-generated on first run — BACK UP app_config.conf)
EncryptionKey []byte // 32 bytes, AES-256 master key
SessionSecret []byte // session cookie signing
// Database
DBDriver string // sqlite | postgres | mysql | mssql
DBPath string // SQLite path
DBDSN string // PostgreSQL/MySQL/MSSQL DSN
// Storage
StorageBackend string // db | fs
StorageFSPath string // base path for fs storage
// Security
MaxMessageSize int64
SessionMaxAge int // seconds
BruteMaxTries int
BruteWindowMin int
BruteBanHours int
TrustedProxies []net.IPNet
SecureCookie bool
BruteWhitelist []net.IP
// SMTP server
SMTPHostname string // override for SMTP HELO (defaults to Hostname)
MaxRcptPer int // max recipients per message
QueueMaxAgeHours int
QueueRetryMins []int // backoff schedule
DNSPrimary string
DNSSecondary string
// DKIM
DKIMSelector string
DKIMAlgo string // rsa2048 | ed25519
// Spam
SpamThreshold int
SpamDNSBL []string
SpamCheckSPF bool
SpamCheckDKIM bool
SpamCheckDMARC bool
// OAuth2 (external accounts)
GoogleClientID string
GoogleClientSecret string
MicrosoftClientID string
MicrosoftClientSecret string
MicrosoftTenantID string
// Debug
Debug bool
LogFile string
LogLevel string // debug | info | warn | error
}
// field drives both config file generation and value parsing.
type field struct {
key string
defVal string
comments []string
secret bool // true = never shown in logs
}
var allFields = []field{
// --- Identity ---
{key: "HOSTNAME", defVal: "mail.example.com", comments: []string{
"--- Server Identity ---",
"FQDN used for SMTP HELO/EHLO, TLS SNI, and URL building.",
"Must resolve in DNS if using TLS_MODE=autocert/dns01/http01.",
}},
{key: "DEFAULT_DOMAIN", defVal: "example.com", comments: []string{
"Primary mail domain served by this instance.",
}},
// --- SMTP ---
{key: "SMTP_IFACE", defVal: "0.0.0.0", comments: []string{
"--- SMTP Server (Inbound MTA) ---",
"Network interface to bind SMTP port 25.",
}},
{key: "SMTP_PORT", defVal: "25"},
{key: "SMTP_ENABLED", defVal: "true"},
{key: "SUBMIT_IFACE", defVal: "0.0.0.0", comments: []string{
"--- SMTP Submission (Authenticated Send) ---",
}},
{key: "SUBMIT_PORT", defVal: "587", comments: []string{"STARTTLS mandatory on this port."}},
{key: "SUBMIT_ENABLED", defVal: "true"},
{key: "SMTPS_PORT", defVal: "465", comments: []string{"Implicit TLS SMTP submission."}},
{key: "SMTPS_ENABLED", defVal: "true"},
// --- IMAP ---
{key: "IMAP_IFACE", defVal: "0.0.0.0", comments: []string{"--- IMAP Server ---"}},
{key: "IMAP_PORT", defVal: "143"},
{key: "IMAP_ENABLED", defVal: "true"},
{key: "IMAPS_PORT", defVal: "993"},
{key: "IMAPS_ENABLED", defVal: "true"},
// --- Web ---
{key: "WEBCLIENT_IFACE", defVal: "0.0.0.0", comments: []string{"--- Web Client ---"}},
{key: "WEBCLIENT_PORT", defVal: "8080"},
{key: "WEBADMIN_IFACE", defVal: "127.0.0.1", comments: []string{
"--- Web Admin ---",
"Default: loopback only. Change to 0.0.0.0 only behind a reverse proxy with auth.",
}},
{key: "WEBADMIN_PORT", defVal: "8081"},
{key: "CALDAV_IFACE", defVal: "0.0.0.0", comments: []string{"--- CalDAV / CardDAV ---"}},
{key: "CALDAV_PORT", defVal: "5232"},
// --- TLS ---
{key: "TLS_MODE", defVal: "dns01", comments: []string{
"--- TLS Configuration ---",
" dns01 = Let's Encrypt via DNS-01 (no open ports, wildcard support) [RECOMMENDED]",
" http01 = Let's Encrypt via HTTP-01 (port 80 must be reachable, no wildcards)",
" manual = Provide TLS_CERT + TLS_KEY paths",
" off = No TLS (use ONLY behind a TLS-terminating reverse proxy)",
}},
{key: "TLS_CERT", defVal: "./certs/cert.pem", comments: []string{
"Path to certificate file (PEM). Used when TLS_MODE=manual.",
"Also used as per-service fallback if SMTP_TLS_CERT etc. are not set.",
}},
{key: "TLS_KEY", defVal: "./certs/key.pem"},
{key: "SMTP_TLS_CERT", defVal: "", comments: []string{"Override TLS cert/key for SMTP services (blank = use global TLS)."}},
{key: "SMTP_TLS_KEY", defVal: ""},
{key: "IMAP_TLS_CERT", defVal: "", comments: []string{"Override TLS cert/key for IMAP services."}},
{key: "IMAP_TLS_KEY", defVal: ""},
{key: "WEB_TLS_CERT", defVal: "", comments: []string{"Override TLS cert/key for web/CalDAV ports."}},
{key: "WEB_TLS_KEY", defVal: ""},
// --- ACME ---
{key: "ACME_EMAIL", defVal: "", comments: []string{
"--- ACME / Let's Encrypt ---",
"Email for Let's Encrypt account registration and renewal notices. Required.",
}},
{key: "ACME_CACHE_DIR", defVal: "./acme-cache", comments: []string{
"Directory to cache ACME account data and certificates.",
}},
{key: "ACME_STAGING", defVal: "false", comments: []string{
"Use Let's Encrypt staging server (rate-limit-free testing). Set false for production.",
}},
{key: "ACME_DOMAINS", defVal: "", comments: []string{
"Comma-separated domains to include in the certificate.",
"Example: example.com,*.example.com,mail.example.com",
"Blank = use HOSTNAME only. Include wildcard for full coverage.",
}},
{key: "ACME_DNS_PROVIDER", defVal: "cloudflare", comments: []string{
"DNS provider for DNS-01 challenge (TLS_MODE=dns01).",
"Supported: cloudflare | route53 | digitalocean | hetzner | ovh | porkbun |",
" namecheap | gandi | desec | acmedns | godaddy | ... (90+ providers)",
"Full list: https://go-acme.github.io/lego/dns/",
}},
// --- Cloudflare ---
{key: "CF_DNS_API_TOKEN", defVal: "", secret: true, comments: []string{
"--- Cloudflare DNS-01 (ACME_DNS_PROVIDER=cloudflare) ---",
"API Token with Zone.DNS:Edit permission on target zone(s).",
"Preferred over CF_API_KEY. Create at: https://dash.cloudflare.com/profile/api-tokens",
}},
{key: "CF_API_KEY", defVal: "", secret: true, comments: []string{"Global API Key (alternative to CF_DNS_API_TOKEN)."}},
{key: "CF_API_EMAIL", defVal: "", comments: []string{"Account email (required with CF_API_KEY, not needed with CF_DNS_API_TOKEN)."}},
// --- Route53 ---
{key: "AWS_REGION", defVal: "", comments: []string{"--- AWS Route53 (ACME_DNS_PROVIDER=route53) ---"}},
{key: "AWS_ACCESS_KEY_ID", defVal: "", secret: true},
{key: "AWS_SECRET_ACCESS_KEY", defVal: "", secret: true},
{key: "AWS_HOSTED_ZONE_ID", defVal: "", comments: []string{"Optional: skip auto-detection."}},
// --- DigitalOcean ---
{key: "DO_AUTH_TOKEN", defVal: "", secret: true, comments: []string{"--- DigitalOcean (ACME_DNS_PROVIDER=digitalocean) ---"}},
// --- Hetzner ---
{key: "HETZNER_API_KEY", defVal: "", secret: true, comments: []string{"--- Hetzner DNS (ACME_DNS_PROVIDER=hetzner) ---"}},
// --- Secrets ---
{key: "ENCRYPTION_KEY", defVal: "", secret: true, comments: []string{
"--- Secrets (auto-generated — BACK UP this file!) ---",
"AES-256 master key for all data at rest (emails, tokens, keys, contacts, calendar).",
"64 hex characters = 32 bytes. Losing this key = permanent data loss.",
}},
{key: "SESSION_SECRET", defVal: "", secret: true, comments: []string{
"Session cookie signing secret. Changing this logs out all users.",
}},
// --- Database ---
{key: "DB_DRIVER", defVal: "sqlite", comments: []string{
"--- Database ---",
"Database driver: sqlite | postgres | mysql | mssql",
}},
{key: "DB_PATH", defVal: "./data/mail.db", comments: []string{"SQLite database path (DB_DRIVER=sqlite)."}},
{key: "DB_DSN", defVal: "", secret: true, comments: []string{
"Connection string for PostgreSQL/MySQL/MSSQL.",
" PostgreSQL: host=localhost port=5432 user=mail password=secret dbname=mail sslmode=require",
" MySQL: mail:secret@tcp(localhost:3306)/mail?tls=true",
" MSSQL: sqlserver://mail:secret@localhost?database=mail",
}},
// --- Storage ---
{key: "STORAGE_BACKEND", defVal: "db", comments: []string{
"--- Email Storage ---",
" db = Store encrypted message blobs in database (simple, single backup file)",
" fs = Store encrypted files on filesystem, metadata in DB (better for large attachments)",
}},
{key: "STORAGE_FS_PATH", defVal: "./data/messages", comments: []string{
"Base directory for filesystem storage (STORAGE_BACKEND=fs).",
}},
// --- Security ---
{key: "MAX_MESSAGE_SIZE", defVal: "52428800", comments: []string{
"--- Security ---",
"Maximum accepted email size in bytes (default 50 MB).",
}},
{key: "SESSION_MAX_AGE", defVal: "604800", comments: []string{"Session lifetime in seconds (default 7 days)."}},
{key: "BRUTE_MAX_TRIES", defVal: "5"},
{key: "BRUTE_WINDOW_MIN", defVal: "30"},
{key: "BRUTE_BAN_HOURS", defVal: "24"},
{key: "BRUTE_WHITELIST_IPS", defVal: "", comments: []string{"Comma-separated IPs exempt from brute-force banning."}},
{key: "TRUSTED_PROXIES", defVal: "", comments: []string{
"Comma-separated CIDR ranges of trusted reverse proxies.",
"Only these may set X-Forwarded-For / X-Forwarded-Proto headers.",
}},
{key: "SECURE_COOKIE", defVal: "false", comments: []string{
"Mark session cookies Secure. Set true when serving over HTTPS.",
"Auto-enabled when BASE_URL starts with https://",
}},
// --- SMTP tuning ---
{key: "SMTP_HOSTNAME", defVal: "", comments: []string{
"--- SMTP Tuning ---",
"SMTP HELO/EHLO hostname override. Blank = use HOSTNAME.",
}},
{key: "MAX_RCPT_PER", defVal: "100", comments: []string{"Maximum recipients per message."}},
{key: "QUEUE_MAX_AGE_HOURS", defVal: "72", comments: []string{"Queue age before bounce (hours)."}},
{key: "QUEUE_RETRY_MINS", defVal: "5,15,60,240,480", comments: []string{"Retry backoff schedule (minutes between attempts)."}},
{key: "DNS_PRIMARY", defVal: "1.1.1.1"},
{key: "DNS_SECONDARY", defVal: "8.8.8.8"},
// --- DKIM ---
{key: "DKIM_SELECTOR", defVal: "mail", comments: []string{
"--- DKIM ---",
"Default DKIM selector for new domains.",
}},
{key: "DKIM_ALGO", defVal: "rsa2048", comments: []string{"Key algorithm: rsa2048 | ed25519"}},
// --- Spam ---
{key: "SPAM_THRESHOLD", defVal: "10", comments: []string{
"--- Spam Filtering ---",
"Messages with spam score >= threshold delivered to Spam folder.",
}},
{key: "SPAM_DNSBL", defVal: "zen.spamhaus.org,bl.spamcop.net"},
{key: "SPAM_CHECK_SPF", defVal: "true"},
{key: "SPAM_CHECK_DKIM", defVal: "true"},
{key: "SPAM_CHECK_DMARC", defVal: "true"},
// --- OAuth2 ---
{key: "GOOGLE_CLIENT_ID", defVal: "", comments: []string{
"--- Google OAuth2 (external Gmail accounts) ---",
"Create at: https://console.cloud.google.com/apis/credentials",
"Required scope: https://mail.google.com/",
}},
{key: "GOOGLE_CLIENT_SECRET", defVal: "", secret: true},
{key: "MICROSOFT_CLIENT_ID", defVal: "", comments: []string{
"--- Microsoft OAuth2 (external Outlook accounts) ---",
"Register at: https://portal.azure.com/#blade/Microsoft_AAD_RegisteredApps",
}},
{key: "MICROSOFT_CLIENT_SECRET", defVal: "", secret: true},
{key: "MICROSOFT_TENANT_ID", defVal: "consumers"},
// --- Debug ---
{key: "DEBUG", defVal: "false", comments: []string{"--- Debug ---"}},
{key: "LOG_FILE", defVal: "./logs/mail.log"},
{key: "LOG_LEVEL", defVal: "info", comments: []string{"Log level: debug | info | warn | error"}},
}
// Load reads app_config.conf, generates it if missing, returns populated Config.
func Load() (*Config, error) {
if err := os.MkdirAll("./data", 0700); err != nil {
return nil, fmt.Errorf("create data dir: %w", err)
}
if err := os.MkdirAll("./logs", 0700); err != nil {
return nil, fmt.Errorf("create logs dir: %w", err)
}
existing, err := readFile(ConfigPath)
if err != nil {
return nil, err
}
// Auto-generate secrets if absent.
if existing["ENCRYPTION_KEY"] == "" {
existing["ENCRYPTION_KEY"] = mustHex(32)
fmt.Fprintln(os.Stderr, "[mailgosend] WARNING: Generated new ENCRYPTION_KEY — back up app_config.conf immediately!")
}
if existing["SESSION_SECRET"] == "" {
existing["SESSION_SECRET"] = mustHex(32)
}
if err := writeFile(ConfigPath, existing); err != nil {
return nil, fmt.Errorf("write config: %w", err)
}
get := func(key string) string {
if v := os.Getenv("MAILGO_" + key); v != "" {
return v
}
return existing[key]
}
// Decode master encryption key.
encHex := get("ENCRYPTION_KEY")
encKey, err := hex.DecodeString(encHex)
if err != nil || len(encKey) != 32 {
return nil, fmt.Errorf("ENCRYPTION_KEY must be 64 hex chars (32 bytes), got %d chars", len(encHex))
}
sessSecret := get("SESSION_SECRET")
if sessSecret == "" {
return nil, fmt.Errorf("SESSION_SECRET missing")
}
trustedProxies, err := parseCIDRs(get("TRUSTED_PROXIES"))
if err != nil {
return nil, fmt.Errorf("TRUSTED_PROXIES: %w", err)
}
acmeDomains := splitTrim(get("ACME_DOMAINS"), ",")
retryMins := parseIntList(get("QUEUE_RETRY_MINS"), []int{5, 15, 60, 240, 480})
dnsbl := splitTrim(get("SPAM_DNSBL"), ",")
smtpHostname := get("SMTP_HOSTNAME")
if smtpHostname == "" {
smtpHostname = get("HOSTNAME")
}
if smtpHostname == "" {
smtpHostname = "localhost"
}
cfg := &Config{
Hostname: orDefault(get("HOSTNAME"), "mail.example.com"),
DefaultDomain: orDefault(get("DEFAULT_DOMAIN"), "example.com"),
SMTPIface: orDefault(get("SMTP_IFACE"), "0.0.0.0"),
SMTPPort: atoi(get("SMTP_PORT"), 25),
SMTPEnabled: atobool(get("SMTP_ENABLED"), true),
SubmitIface: orDefault(get("SUBMIT_IFACE"), "0.0.0.0"),
SubmitPort: atoi(get("SUBMIT_PORT"), 587),
SubmitEnabled: atobool(get("SUBMIT_ENABLED"), true),
SMTPSPort: atoi(get("SMTPS_PORT"), 465),
SMTPSEnabled: atobool(get("SMTPS_ENABLED"), true),
IMAPIface: orDefault(get("IMAP_IFACE"), "0.0.0.0"),
IMAPPort: atoi(get("IMAP_PORT"), 143),
IMAPEnabled: atobool(get("IMAP_ENABLED"), true),
IMAPSPort: atoi(get("IMAPS_PORT"), 993),
IMAPSEnabled: atobool(get("IMAPS_ENABLED"), true),
WebClientIface: orDefault(get("WEBCLIENT_IFACE"), "0.0.0.0"),
WebClientPort: atoi(get("WEBCLIENT_PORT"), 8080),
WebAdminIface: orDefault(get("WEBADMIN_IFACE"), "127.0.0.1"),
WebAdminPort: atoi(get("WEBADMIN_PORT"), 8081),
CalDAVIface: orDefault(get("CALDAV_IFACE"), "0.0.0.0"),
CalDAVPort: atoi(get("CALDAV_PORT"), 5232),
TLSMode: orDefault(get("TLS_MODE"), "dns01"),
TLSCert: get("TLS_CERT"),
TLSKey: get("TLS_KEY"),
SMTPTLSCert: get("SMTP_TLS_CERT"),
SMTPTLSKey: get("SMTP_TLS_KEY"),
IMAPTLSCert: get("IMAP_TLS_CERT"),
IMAPTLSKey: get("IMAP_TLS_KEY"),
WebTLSCert: get("WEB_TLS_CERT"),
WebTLSKey: get("WEB_TLS_KEY"),
ACMEEmail: get("ACME_EMAIL"),
ACMECacheDir: orDefault(get("ACME_CACHE_DIR"), "./acme-cache"),
ACMEStaging: atobool(get("ACME_STAGING"), false),
ACMEDomains: acmeDomains,
ACMEDNSProvider: orDefault(get("ACME_DNS_PROVIDER"), "cloudflare"),
CFDNSAPIToken: get("CF_DNS_API_TOKEN"),
CFAPIKey: get("CF_API_KEY"),
CFAPIEmail: get("CF_API_EMAIL"),
AWSRegion: get("AWS_REGION"),
AWSAccessKeyID: get("AWS_ACCESS_KEY_ID"),
AWSSecretAccessKey: get("AWS_SECRET_ACCESS_KEY"),
AWSHostedZoneID: get("AWS_HOSTED_ZONE_ID"),
DOAuthToken: get("DO_AUTH_TOKEN"),
HetznerAPIKey: get("HETZNER_API_KEY"),
EncryptionKey: encKey,
SessionSecret: []byte(sessSecret),
DBDriver: orDefault(get("DB_DRIVER"), "sqlite"),
DBPath: orDefault(get("DB_PATH"), "./data/mail.db"),
DBDSN: get("DB_DSN"),
StorageBackend: orDefault(get("STORAGE_BACKEND"), "db"),
StorageFSPath: orDefault(get("STORAGE_FS_PATH"), "./data/messages"),
MaxMessageSize: int64(atoi(get("MAX_MESSAGE_SIZE"), 52428800)),
SessionMaxAge: atoi(get("SESSION_MAX_AGE"), 604800),
BruteMaxTries: atoi(get("BRUTE_MAX_TRIES"), 5),
BruteWindowMin: atoi(get("BRUTE_WINDOW_MIN"), 30),
BruteBanHours: atoi(get("BRUTE_BAN_HOURS"), 24),
BruteWhitelist: parseIPs(get("BRUTE_WHITELIST_IPS")),
TrustedProxies: trustedProxies,
SecureCookie: atobool(get("SECURE_COOKIE"), false),
SMTPHostname: smtpHostname,
MaxRcptPer: atoi(get("MAX_RCPT_PER"), 100),
QueueMaxAgeHours: atoi(get("QUEUE_MAX_AGE_HOURS"), 72),
QueueRetryMins: retryMins,
DNSPrimary: orDefault(get("DNS_PRIMARY"), "1.1.1.1"),
DNSSecondary: orDefault(get("DNS_SECONDARY"), "8.8.8.8"),
DKIMSelector: orDefault(get("DKIM_SELECTOR"), "mail"),
DKIMAlgo: orDefault(get("DKIM_ALGO"), "rsa2048"),
SpamThreshold: atoi(get("SPAM_THRESHOLD"), 10),
SpamDNSBL: dnsbl,
SpamCheckSPF: atobool(get("SPAM_CHECK_SPF"), true),
SpamCheckDKIM: atobool(get("SPAM_CHECK_DKIM"), true),
SpamCheckDMARC: atobool(get("SPAM_CHECK_DMARC"), true),
GoogleClientID: get("GOOGLE_CLIENT_ID"),
GoogleClientSecret: get("GOOGLE_CLIENT_SECRET"),
MicrosoftClientID: get("MICROSOFT_CLIENT_ID"),
MicrosoftClientSecret: get("MICROSOFT_CLIENT_SECRET"),
MicrosoftTenantID: orDefault(get("MICROSOFT_TENANT_ID"), "consumers"),
Debug: atobool(get("DEBUG"), false),
LogFile: orDefault(get("LOG_FILE"), "./logs/mail.log"),
LogLevel: orDefault(get("LOG_LEVEL"), "info"),
}
// Export provider-specific env vars so lego can pick them up.
cfg.exportProviderEnv()
logStartup(cfg)
return cfg, nil
}
// exportProviderEnv sets standard lego env vars from config values.
// This allows using app_config.conf as the single source of truth.
func (c *Config) exportProviderEnv() {
setenv := func(key, val string) {
if val != "" && os.Getenv(key) == "" {
os.Setenv(key, val) //nolint:errcheck
}
}
setenv("CF_DNS_API_TOKEN", c.CFDNSAPIToken)
setenv("CF_API_KEY", c.CFAPIKey)
setenv("CF_API_EMAIL", c.CFAPIEmail)
setenv("AWS_REGION", c.AWSRegion)
setenv("AWS_ACCESS_KEY_ID", c.AWSAccessKeyID)
setenv("AWS_SECRET_ACCESS_KEY", c.AWSSecretAccessKey)
setenv("AWS_HOSTED_ZONE_ID", c.AWSHostedZoneID)
setenv("DO_AUTH_TOKEN", c.DOAuthToken)
setenv("HETZNER_API_KEY", c.HetznerAPIKey)
}
// ACMEDomainList returns ACME_DOMAINS, falling back to Hostname.
func (c *Config) ACMEDomainList() []string {
if len(c.ACMEDomains) > 0 {
return c.ACMEDomains
}
return []string{c.Hostname}
}
// RealIP extracts the client IP, honouring X-Forwarded-For from trusted proxies.
func (c *Config) RealIP(remoteAddr, xForwardedFor string) string {
remoteIP, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
remoteIP = remoteAddr
}
if xForwardedFor == "" || !c.isTrustedProxy(remoteIP) {
return remoteIP
}
parts := strings.Split(xForwardedFor, ",")
if len(parts) > 0 {
if ip := strings.TrimSpace(parts[0]); net.ParseIP(ip) != nil {
return ip
}
}
return remoteIP
}
func (c *Config) isTrustedProxy(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}
for _, cidr := range c.TrustedProxies {
if cidr.Contains(ip) {
return true
}
}
return false
}
func (c *Config) IsIPWhitelisted(ip string) bool {
parsed := net.ParseIP(ip)
if parsed == nil {
return false
}
for _, w := range c.BruteWhitelist {
if w.Equal(parsed) {
return true
}
}
return false
}
// ---- Config file I/O ----
func readFile(path string) (map[string]string, error) {
vals := make(map[string]string)
f, err := os.Open(path)
if os.IsNotExist(err) {
return vals, nil
}
if err != nil {
return nil, fmt.Errorf("open config: %w", err)
}
defer f.Close()
sc := bufio.NewScanner(f)
for sc.Scan() {
line := strings.TrimSpace(sc.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
idx := strings.IndexByte(line, '=')
if idx < 0 {
continue
}
k := strings.TrimSpace(line[:idx])
v := strings.TrimSpace(line[idx+1:])
vals[k] = v
}
return vals, sc.Err()
}
func writeFile(path string, existing map[string]string) error {
var sb strings.Builder
sb.WriteString("# mailgosend Configuration\n")
sb.WriteString("# =========================\n")
sb.WriteString("# Auto-generated and updated on each startup.\n")
sb.WriteString("# Edit freely — your values are always preserved.\n")
sb.WriteString("# Override any key via env var: MAILGO_<KEY>=value\n")
sb.WriteString("#\n\n")
for _, f := range allFields {
for _, c := range f.comments {
if c == "" {
sb.WriteString("#\n")
} else {
sb.WriteString("# " + c + "\n")
}
}
v := existing[f.key]
if v == "" {
v = f.defVal
}
sb.WriteString(f.key + " = " + v + "\n\n")
}
return os.WriteFile(path, []byte(sb.String()), 0600)
}
// ---- Startup log ----
func logStartup(c *Config) {
fmt.Printf("mailgosend starting\n")
fmt.Printf(" Hostname : %s\n", c.Hostname)
fmt.Printf(" Default domain: %s\n", c.DefaultDomain)
fmt.Printf(" TLS mode : %s\n", c.TLSMode)
if c.TLSMode == "dns01" || c.TLSMode == "http01" {
fmt.Printf(" ACME provider: %s\n", c.ACMEDNSProvider)
fmt.Printf(" ACME domains : %v\n", c.ACMEDomainList())
}
fmt.Printf(" DB driver : %s\n", c.DBDriver)
if c.SMTPEnabled {
fmt.Printf(" SMTP : %s:%d\n", c.SMTPIface, c.SMTPPort)
}
if c.SubmitEnabled {
fmt.Printf(" Submission : %s:%d (STARTTLS)\n", c.SubmitIface, c.SubmitPort)
}
if c.SMTPSEnabled {
fmt.Printf(" SMTPS : %s:%d (TLS)\n", c.SMTPIface, c.SMTPSPort)
}
if c.IMAPEnabled {
fmt.Printf(" IMAP : %s:%d (STARTTLS)\n", c.IMAPIface, c.IMAPPort)
}
if c.IMAPSEnabled {
fmt.Printf(" IMAPS : %s:%d (TLS)\n", c.IMAPIface, c.IMAPSPort)
}
fmt.Printf(" Web client : %s:%d\n", c.WebClientIface, c.WebClientPort)
fmt.Printf(" Web admin : %s:%d\n", c.WebAdminIface, c.WebAdminPort)
fmt.Printf(" CalDAV : %s:%d\n", c.CalDAVIface, c.CalDAVPort)
}
// ---- Helpers ----
func mustHex(n int) string {
b := make([]byte, n)
if _, err := rand.Read(b); err != nil {
panic("crypto/rand unavailable: " + err.Error())
}
return hex.EncodeToString(b)
}
func atoi(s string, fallback int) int {
if v, err := strconv.Atoi(s); err == nil {
return v
}
return fallback
}
func atobool(s string, fallback bool) bool {
if v, err := strconv.ParseBool(s); err == nil {
return v
}
return fallback
}
func orDefault(s, def string) string {
if s == "" {
return def
}
return s
}
func splitTrim(s, sep string) []string {
var out []string
for _, p := range strings.Split(s, sep) {
p = strings.TrimSpace(p)
if p != "" {
out = append(out, p)
}
}
return out
}
func parseIntList(s string, fallback []int) []int {
parts := splitTrim(s, ",")
if len(parts) == 0 {
return fallback
}
out := make([]int, 0, len(parts))
for _, p := range parts {
if v, err := strconv.Atoi(p); err == nil {
out = append(out, v)
}
}
if len(out) == 0 {
return fallback
}
return out
}
func parseCIDRs(s string) ([]net.IPNet, error) {
var nets []net.IPNet
for _, raw := range splitTrim(s, ",") {
if !strings.Contains(raw, "/") {
ip := net.ParseIP(raw)
if ip == nil {
return nil, fmt.Errorf("invalid IP %q", raw)
}
bits := 32
if ip.To4() == nil {
bits = 128
}
raw = fmt.Sprintf("%s/%d", ip.String(), bits)
}
_, ipNet, err := net.ParseCIDR(raw)
if err != nil {
return nil, fmt.Errorf("invalid CIDR %q: %w", raw, err)
}
nets = append(nets, *ipNet)
}
return nets, nil
}
func parseIPs(s string) []net.IP {
var ips []net.IP
for _, raw := range splitTrim(s, ",") {
if ip := net.ParseIP(raw); ip != nil {
ips = append(ips, ip)
}
}
return ips
}
+216
View File
@@ -0,0 +1,216 @@
// Package crypto provides AES-256-GCM encryption, HKDF key derivation,
// bcrypt helpers, and secure random utilities.
// All encryption uses authenticated encryption — tampering is detected.
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"fmt"
"io"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/hkdf"
)
const (
// BcryptCost is the minimum bcrypt work factor.
BcryptCost = 12
// keyLen is AES-256 key length in bytes.
keyLen = 32
// gcmNonceLen is the standard GCM nonce size.
gcmNonceLen = 12
)
// Crypto holds the master encryption key.
// One instance per application, injected everywhere that needs encryption.
type Crypto struct {
masterKey [keyLen]byte
}
// New creates a Crypto instance from a 32-byte master key.
func New(masterKey []byte) (*Crypto, error) {
if len(masterKey) != keyLen {
return nil, fmt.Errorf("crypto: master key must be %d bytes, got %d", keyLen, len(masterKey))
}
c := &Crypto{}
copy(c.masterKey[:], masterKey)
return c, nil
}
// DeriveKey returns a unique 32-byte AES-256 key for a given purpose + userID.
// Uses HKDF-SHA256 so each (purpose, userID) pair gets a unique subkey,
// and the master key is never used directly for encryption.
func (c *Crypto) DeriveKey(purpose string, userID int64) ([keyLen]byte, error) {
var key [keyLen]byte
info := fmt.Sprintf("%s:user:%d", purpose, userID)
r := hkdf.New(sha256.New, c.masterKey[:], nil, []byte(info))
if _, err := io.ReadFull(r, key[:]); err != nil {
return key, fmt.Errorf("hkdf derive: %w", err)
}
return key, nil
}
// DeriveKeyGlobal returns a 32-byte key derived from master key for global use
// (e.g. encrypting DKIM private keys stored per-domain, not per-user).
func (c *Crypto) DeriveKeyGlobal(purpose string) ([keyLen]byte, error) {
var key [keyLen]byte
r := hkdf.New(sha256.New, c.masterKey[:], nil, []byte("global:"+purpose))
if _, err := io.ReadFull(r, key[:]); err != nil {
return key, fmt.Errorf("hkdf derive global: %w", err)
}
return key, nil
}
// Encrypt encrypts plaintext with AES-256-GCM using the provided 32-byte key.
// Returns nonce||ciphertext||tag (nonce prepended, all opaque bytes).
// Returns an error if plaintext is nil (use []byte{} for empty).
func Encrypt(key [keyLen]byte, plaintext []byte) ([]byte, error) {
block, err := aes.NewCipher(key[:])
if err != nil {
return nil, fmt.Errorf("aes new cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("aes gcm: %w", err)
}
nonce := make([]byte, gcmNonceLen)
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("rand nonce: %w", err)
}
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
return ciphertext, nil
}
// Decrypt decrypts a nonce||ciphertext||tag blob produced by Encrypt.
func Decrypt(key [keyLen]byte, ciphertext []byte) ([]byte, error) {
if len(ciphertext) < gcmNonceLen {
return nil, fmt.Errorf("ciphertext too short")
}
block, err := aes.NewCipher(key[:])
if err != nil {
return nil, fmt.Errorf("aes new cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("aes gcm: %w", err)
}
nonce := ciphertext[:gcmNonceLen]
data := ciphertext[gcmNonceLen:]
plaintext, err := gcm.Open(nil, nonce, data, nil)
if err != nil {
return nil, fmt.Errorf("decrypt: %w", err) // do not leak GCM error details
}
return plaintext, nil
}
// EncryptForUser derives a per-user key and encrypts.
func (c *Crypto) EncryptForUser(userID int64, purpose string, plaintext []byte) ([]byte, error) {
key, err := c.DeriveKey(purpose, userID)
if err != nil {
return nil, err
}
return Encrypt(key, plaintext)
}
// DecryptForUser derives a per-user key and decrypts.
func (c *Crypto) DecryptForUser(userID int64, purpose string, ciphertext []byte) ([]byte, error) {
key, err := c.DeriveKey(purpose, userID)
if err != nil {
return nil, err
}
return Decrypt(key, ciphertext)
}
// EncryptGlobal derives a global key for given purpose and encrypts.
func (c *Crypto) EncryptGlobal(purpose string, plaintext []byte) ([]byte, error) {
key, err := c.DeriveKeyGlobal(purpose)
if err != nil {
return nil, err
}
return Encrypt(key, plaintext)
}
// DecryptGlobal derives a global key for given purpose and decrypts.
func (c *Crypto) DecryptGlobal(purpose string, ciphertext []byte) ([]byte, error) {
key, err := c.DeriveKeyGlobal(purpose)
if err != nil {
return nil, err
}
return Decrypt(key, ciphertext)
}
// ---- Bcrypt ----
// HashPassword hashes a password with bcrypt at cost BcryptCost.
func HashPassword(password string) (string, error) {
if password == "" {
return "", fmt.Errorf("password must not be empty")
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), BcryptCost)
if err != nil {
return "", fmt.Errorf("bcrypt: %w", err)
}
return string(hash), nil
}
// CheckPassword returns nil if password matches the stored bcrypt hash.
// Uses constant-time comparison internally (bcrypt).
func CheckPassword(hash, password string) error {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
}
// ---- Session tokens ----
// NewToken generates a cryptographically random 32-byte token and returns
// (rawToken, sha256HexHash). Store the hash; send the raw token to the client.
func NewToken() (raw string, hash string, err error) {
b := make([]byte, 32)
if _, err = rand.Read(b); err != nil {
return "", "", fmt.Errorf("rand token: %w", err)
}
raw = hex.EncodeToString(b)
hash = HashToken(raw)
return raw, hash, nil
}
// HashToken returns the SHA-256 hex hash of a raw token string.
func HashToken(raw string) string {
h := sha256.Sum256([]byte(raw))
return hex.EncodeToString(h[:])
}
// SecureCompare returns true if a == b using constant-time comparison.
// Use for any comparison where timing attacks are a concern.
func SecureCompare(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}
// ---- Random helpers ----
// RandomHex returns n random bytes as a hex string (length 2n).
func RandomHex(n int) (string, error) {
b := make([]byte, n)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("rand: %w", err)
}
return hex.EncodeToString(b), nil
}
// RandomBytes returns n cryptographically random bytes.
func RandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
if _, err := rand.Read(b); err != nil {
return nil, fmt.Errorf("rand: %w", err)
}
return b, nil
}
+171
View File
@@ -0,0 +1,171 @@
// Package db provides the database/sql wrapper and driver registration.
// Default driver: modernc.org/sqlite (pure Go, no CGO).
// Additional drivers registered via build tags: postgres, mysql, mssql.
package db
import (
"context"
"database/sql"
"fmt"
"time"
_ "modernc.org/sqlite" // pure-Go SQLite driver
)
// DB wraps sql.DB with convenience methods and prepared statement caching.
type DB struct {
db *sql.DB
driver string
}
// Open opens and validates the database connection, runs migrations, returns DB.
func Open(driver, dsn string) (*DB, error) {
if driver == "" {
driver = "sqlite"
}
// Map friendly driver names to database/sql driver names.
sqlDriver := sqlDriverName(driver)
if driver == "sqlite" {
dsn = sqliteDSN(dsn)
}
sqlDB, err := sql.Open(sqlDriver, dsn)
if err != nil {
return nil, fmt.Errorf("db open %s: %w", driver, err)
}
// Connection pool tuning.
if driver == "sqlite" {
// SQLite: serialise with single connection to avoid SQLITE_BUSY.
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxIdleConns(1)
sqlDB.SetConnMaxLifetime(0)
} else {
sqlDB.SetMaxOpenConns(25)
sqlDB.SetMaxIdleConns(10)
sqlDB.SetConnMaxLifetime(5 * time.Minute)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := sqlDB.PingContext(ctx); err != nil {
sqlDB.Close()
return nil, fmt.Errorf("db ping: %w", err)
}
d := &DB{db: sqlDB, driver: driver}
// Enable WAL mode for SQLite (dramatically improves concurrent read performance).
if driver == "sqlite" {
if _, err := sqlDB.Exec(`PRAGMA journal_mode=WAL`); err != nil {
sqlDB.Close()
return nil, fmt.Errorf("sqlite WAL: %w", err)
}
if _, err := sqlDB.Exec(`PRAGMA foreign_keys=ON`); err != nil {
sqlDB.Close()
return nil, fmt.Errorf("sqlite foreign_keys: %w", err)
}
if _, err := sqlDB.Exec(`PRAGMA busy_timeout=5000`); err != nil {
sqlDB.Close()
return nil, fmt.Errorf("sqlite busy_timeout: %w", err)
}
}
if err := d.migrate(); err != nil {
sqlDB.Close()
return nil, fmt.Errorf("migrate: %w", err)
}
return d, nil
}
// Close closes the underlying sql.DB.
func (d *DB) Close() error { return d.db.Close() }
// Driver returns the driver name (sqlite / postgres / mysql / mssql).
func (d *DB) Driver() string { return d.driver }
// SQL returns the underlying *sql.DB for direct use when needed.
func (d *DB) SQL() *sql.DB { return d.db }
// Exec runs a query with a per-call context timeout.
func (d *DB) Exec(query string, args ...any) (sql.Result, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return d.db.ExecContext(ctx, query, args...)
}
// QueryRow runs a single-row query.
func (d *DB) QueryRow(query string, args ...any) *sql.Row {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return d.db.QueryRowContext(ctx, query, args...)
}
// Query runs a multi-row query.
func (d *DB) Query(query string, args ...any) (*sql.Rows, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return d.db.QueryContext(ctx, query, args...)
}
// WithTx runs fn inside a transaction, rolling back on error or panic.
func (d *DB) WithTx(ctx context.Context, fn func(*sql.Tx) error) error {
tx, err := d.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin tx: %w", err)
}
defer func() {
if p := recover(); p != nil {
_ = tx.Rollback()
panic(p) // re-raise
}
}()
if err := fn(tx); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
// ---- Placeholder helper ----
// Placeholder returns the SQL parameter placeholder for the current driver.
// SQLite and MySQL use ?, PostgreSQL uses $1, $2… MSSQL uses @p1, @p2…
func (d *DB) Placeholder(n int) string {
switch d.driver {
case "postgres":
return fmt.Sprintf("$%d", n)
case "mssql":
return fmt.Sprintf("@p%d", n)
default:
return "?"
}
}
// ---- Private helpers ----
func sqlDriverName(driver string) string {
switch driver {
case "sqlite":
return "sqlite" // modernc.org/sqlite registers as "sqlite"
case "postgres":
return "postgres"
case "mysql":
return "mysql"
case "mssql":
return "sqlserver"
default:
return driver
}
}
func sqliteDSN(path string) string {
if path == "" {
path = "./data/mail.db"
}
// modernc.org/sqlite DSN supports query parameters.
return path + "?_pragma=foreign_keys(1)"
}
+120
View File
@@ -0,0 +1,120 @@
package db
import (
"context"
"database/sql"
"fmt"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
)
// GetDomain returns the domain row by name, or nil if not found.
func (d *DB) GetDomain(ctx context.Context, name string) (*models.Domain, error) {
row := d.db.QueryRowContext(ctx, `
SELECT id, name, enabled, dkim_private_enc, dkim_public, dkim_selector,
dkim_algo, spf_policy, dmarc_policy, max_users, max_quota_bytes, created_at
FROM domains WHERE lower(name) = lower(?)`, name)
var dom models.Domain
var privEnc []byte
err := row.Scan(
&dom.ID, &dom.Name, &dom.Enabled,
&privEnc, &dom.DKIMPublic, &dom.DKIMSelector,
&dom.DKIMAlgo, &dom.SPFPolicy, &dom.DMARCPolicy,
&dom.MaxUsers, &dom.MaxQuotaBytes, &dom.CreatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("get domain: %w", err)
}
dom.DKIMPrivateEnc = privEnc
return &dom, nil
}
// GetDomainByID returns the domain row by ID.
func (d *DB) GetDomainByID(ctx context.Context, id int64) (*models.Domain, error) {
row := d.db.QueryRowContext(ctx, `
SELECT id, name, enabled, dkim_private_enc, dkim_public, dkim_selector,
dkim_algo, spf_policy, dmarc_policy, max_users, max_quota_bytes, created_at
FROM domains WHERE id = ?`, id)
var dom models.Domain
var privEnc []byte
err := row.Scan(
&dom.ID, &dom.Name, &dom.Enabled,
&privEnc, &dom.DKIMPublic, &dom.DKIMSelector,
&dom.DKIMAlgo, &dom.SPFPolicy, &dom.DMARCPolicy,
&dom.MaxUsers, &dom.MaxQuotaBytes, &dom.CreatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("get domain by id: %w", err)
}
dom.DKIMPrivateEnc = privEnc
return &dom, nil
}
// IsLocalDomain returns true if name is a known enabled domain.
func (d *DB) IsLocalDomain(ctx context.Context, name string) (bool, error) {
var count int
err := d.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM domains WHERE lower(name)=lower(?) AND enabled=1", name).
Scan(&count)
if err != nil {
return false, err
}
return count > 0, nil
}
// ListDomains returns all domains ordered by name.
func (d *DB) ListDomains(ctx context.Context) ([]*models.Domain, error) {
rows, err := d.db.QueryContext(ctx, `
SELECT id, name, enabled, dkim_private_enc, dkim_public, dkim_selector,
dkim_algo, spf_policy, dmarc_policy, max_users, max_quota_bytes, created_at
FROM domains ORDER BY name`)
if err != nil {
return nil, err
}
defer rows.Close()
var doms []*models.Domain
for rows.Next() {
var dom models.Domain
var privEnc []byte
err := rows.Scan(
&dom.ID, &dom.Name, &dom.Enabled,
&privEnc, &dom.DKIMPublic, &dom.DKIMSelector,
&dom.DKIMAlgo, &dom.SPFPolicy, &dom.DMARCPolicy,
&dom.MaxUsers, &dom.MaxQuotaBytes, &dom.CreatedAt,
)
if err != nil {
return nil, err
}
dom.DKIMPrivateEnc = privEnc
doms = append(doms, &dom)
}
return doms, rows.Err()
}
// CreateDomain inserts a new domain. Returns the new ID.
func (d *DB) CreateDomain(ctx context.Context, name, selector, algo string) (int64, error) {
res, err := d.db.ExecContext(ctx, `
INSERT INTO domains (name, enabled, dkim_selector, dkim_algo)
VALUES (?, 1, ?, ?)`, name, selector, algo)
if err != nil {
return 0, fmt.Errorf("create domain: %w", err)
}
return res.LastInsertId()
}
// SaveDKIMKeys stores encrypted DKIM private key + public key for a domain.
func (d *DB) SaveDKIMKeys(ctx context.Context, domainID int64, privEnc []byte, pubPEM string) error {
_, err := d.db.ExecContext(ctx,
"UPDATE domains SET dkim_private_enc=?, dkim_public=? WHERE id=?",
privEnc, pubPEM, domainID)
return err
}
+284
View File
@@ -0,0 +1,284 @@
package db
import (
"context"
"database/sql"
"fmt"
"time"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
)
// IMAPMessage is a lightweight message descriptor used by the IMAP layer.
// The raw/body blobs are NOT loaded here — fetch separately via GetMessageRaw.
type IMAPMessage struct {
ID int64
MailboxID int64
UID uint32
MessageID string // RFC 2822 Message-ID header
Subject string
FromEmail string
FromName string
ToList string
Date time.Time
SizeBytes int64
HasAttachment bool
IsRead bool
IsStarred bool
IsDraft bool
IsDeleted bool // deleted_at IS NOT NULL
Flags string
SpamScore int
ReceivedAt time.Time
}
// ListIMAPMessages returns all non-deleted messages in a mailbox ordered by UID ascending.
func (d *DB) ListIMAPMessages(ctx context.Context, mailboxID int64) ([]*IMAPMessage, error) {
rows, err := d.db.QueryContext(ctx, `
SELECT id, mailbox_id, uid, message_id, subject, from_email, from_name,
to_list, date, size_bytes, has_attachment,
is_read, is_starred, is_draft, flags, spam_score, received_at
FROM messages
WHERE mailbox_id = ? AND deleted_at IS NULL
ORDER BY uid ASC`, mailboxID)
if err != nil {
return nil, err
}
defer rows.Close()
var out []*IMAPMessage
for rows.Next() {
m, err := scanIMAPMessage(rows)
if err != nil {
return nil, err
}
out = append(out, m)
}
return out, rows.Err()
}
// GetIMAPMessageByUID returns one message by UID within a mailbox.
func (d *DB) GetIMAPMessageByUID(ctx context.Context, mailboxID int64, uid uint32) (*IMAPMessage, error) {
rows, err := d.db.QueryContext(ctx, `
SELECT id, mailbox_id, uid, message_id, subject, from_email, from_name,
to_list, date, size_bytes, has_attachment,
is_read, is_starred, is_draft, flags, spam_score, received_at
FROM messages
WHERE mailbox_id = ? AND uid = ? AND deleted_at IS NULL
LIMIT 1`, mailboxID, uid)
if err != nil {
return nil, err
}
defer rows.Close()
if !rows.Next() {
return nil, nil
}
return scanIMAPMessage(rows)
}
// SetMessageFlags updates the mutable flags for a message.
func (d *DB) SetMessageFlags(ctx context.Context, messageID int64, isRead, isStarred, isDraft bool, extraFlags string) error {
_, err := d.db.ExecContext(ctx,
"UPDATE messages SET is_read=?, is_starred=?, is_draft=?, flags=? WHERE id=?",
isRead, isStarred, isDraft, extraFlags, messageID)
return err
}
// SoftDeleteMessage marks a message as deleted (sets deleted_at).
func (d *DB) SoftDeleteMessage(ctx context.Context, messageID int64) error {
_, err := d.db.ExecContext(ctx,
"UPDATE messages SET deleted_at=? WHERE id=?", time.Now().UTC(), messageID)
return err
}
// HardDeleteMessages physically removes all soft-deleted messages from a mailbox.
// Returns the UIDs of deleted messages (for EXPUNGE responses).
func (d *DB) HardDeleteMessages(ctx context.Context, mailboxID int64) ([]uint32, error) {
rows, err := d.db.QueryContext(ctx,
"SELECT uid FROM messages WHERE mailbox_id=? AND deleted_at IS NOT NULL ORDER BY uid ASC",
mailboxID)
if err != nil {
return nil, err
}
var uids []uint32
for rows.Next() {
var uid uint32
if err := rows.Scan(&uid); err != nil {
rows.Close()
return nil, err
}
uids = append(uids, uid)
}
rows.Close()
if err := rows.Err(); err != nil {
return nil, err
}
if len(uids) == 0 {
return nil, nil
}
// Delete attachments first (FK).
_, err = d.db.ExecContext(ctx, `
DELETE FROM attachments WHERE message_id IN (
SELECT id FROM messages WHERE mailbox_id=? AND deleted_at IS NOT NULL
)`, mailboxID)
if err != nil {
return nil, fmt.Errorf("delete attachments: %w", err)
}
_, err = d.db.ExecContext(ctx,
"DELETE FROM messages WHERE mailbox_id=? AND deleted_at IS NOT NULL", mailboxID)
if err != nil {
return nil, fmt.Errorf("delete messages: %w", err)
}
return uids, nil
}
// CopyMessageToMailbox duplicates a message row to another mailbox.
// Returns the new UID.
func (d *DB) CopyMessageToMailbox(ctx context.Context, srcMsgID, destMailboxID, userID int64) (uint32, error) {
// Read source.
var src struct {
mailboxID int64
uid uint32
messageID string
subject string
fromEmail string
fromName string
toList string
ccList string
bccList string
replyTo string
date time.Time
bodyTextEnc []byte
bodyHTMLEnc []byte
rawEnc []byte
sizeBytes int64
hasAttachment bool
isRead bool
isStarred bool
isDraft bool
flags string
spamScore int
}
err := d.db.QueryRowContext(ctx, `
SELECT mailbox_id, uid, message_id, subject, from_email, from_name,
to_list, cc_list, bcc_list, reply_to, date,
body_text_enc, body_html_enc, raw_enc,
size_bytes, has_attachment, is_read, is_starred, is_draft,
flags, spam_score
FROM messages WHERE id=? AND deleted_at IS NULL`, srcMsgID).Scan(
&src.mailboxID, &src.uid, &src.messageID, &src.subject,
&src.fromEmail, &src.fromName, &src.toList, &src.ccList, &src.bccList, &src.replyTo,
&src.date, &src.bodyTextEnc, &src.bodyHTMLEnc, &src.rawEnc,
&src.sizeBytes, &src.hasAttachment, &src.isRead, &src.isStarred, &src.isDraft,
&src.flags, &src.spamScore,
)
if err == sql.ErrNoRows {
return 0, fmt.Errorf("source message %d not found", srcMsgID)
}
if err != nil {
return 0, fmt.Errorf("copy message read: %w", err)
}
// Allocate UID in destination.
uid, err := d.NextUID(ctx, destMailboxID)
if err != nil {
return 0, fmt.Errorf("copy message uid: %w", err)
}
ins := &MessageInsert{
MailboxID: destMailboxID,
UID: uid,
MessageID: src.messageID,
Subject: src.subject,
FromEmail: src.fromEmail,
FromName: src.fromName,
ToList: src.toList,
CCList: src.ccList,
BCCList: src.bccList,
ReplyTo: src.replyTo,
Date: src.date,
BodyTextEnc: src.bodyTextEnc,
BodyHTMLEnc: src.bodyHTMLEnc,
RawEnc: src.rawEnc,
SizeBytes: src.sizeBytes,
HasAttachment: src.hasAttachment,
IsRead: src.isRead,
IsStarred: src.isStarred,
IsDraft: src.isDraft,
Flags: src.flags,
SpamScore: src.spamScore,
}
if _, err := d.InsertMessage(ctx, ins); err != nil {
return 0, fmt.Errorf("copy message insert: %w", err)
}
return uid, nil
}
// RenameMailbox updates the name field of a mailbox.
func (d *DB) RenameMailbox(ctx context.Context, mailboxID int64, newName string) error {
_, err := d.db.ExecContext(ctx,
"UPDATE mailboxes SET name=? WHERE id=?", newName, mailboxID)
return err
}
// SetMailboxSubscribed updates the subscribed flag on a mailbox.
func (d *DB) SetMailboxSubscribed(ctx context.Context, mailboxID int64, subscribed bool) error {
_, err := d.db.ExecContext(ctx,
"UPDATE mailboxes SET subscribed=? WHERE id=?", subscribed, mailboxID)
return err
}
// GetMailboxMessageCounts returns (total, unseen) counts for a mailbox.
func (d *DB) GetMailboxMessageCounts(ctx context.Context, mailboxID int64) (total, unseen int64, err error) {
err = d.db.QueryRowContext(ctx, `
SELECT COUNT(*), COUNT(CASE WHEN is_read=0 THEN 1 END)
FROM messages WHERE mailbox_id=? AND deleted_at IS NULL`, mailboxID).Scan(&total, &unseen)
return
}
// GetMailboxSize returns the total size in bytes of all messages in a mailbox.
func (d *DB) GetMailboxSize(ctx context.Context, mailboxID int64) (int64, error) {
var sz sql.NullInt64
err := d.db.QueryRowContext(ctx,
"SELECT SUM(size_bytes) FROM messages WHERE mailbox_id=? AND deleted_at IS NULL",
mailboxID).Scan(&sz)
if err != nil {
return 0, err
}
return sz.Int64, nil
}
// ---- helpers ----
func scanIMAPMessage(rows *sql.Rows) (*IMAPMessage, error) {
m := &IMAPMessage{}
err := rows.Scan(
&m.ID, &m.MailboxID, &m.UID, &m.MessageID, &m.Subject,
&m.FromEmail, &m.FromName, &m.ToList,
&m.Date, &m.SizeBytes, &m.HasAttachment,
&m.IsRead, &m.IsStarred, &m.IsDraft,
&m.Flags, &m.SpamScore, &m.ReceivedAt,
)
return m, err
}
// mailboxTypeToAttr converts our type string to an IMAP special-use string.
// Callers handle the conversion to imap.MailboxAttr themselves.
func MailboxTypeToSpecialUse(mboxType string) string {
switch mboxType {
case models.MailboxSent:
return `\Sent`
case models.MailboxDrafts:
return `\Drafts`
case models.MailboxTrash:
return `\Trash`
case models.MailboxSpam:
return `\Junk`
case models.MailboxArchive:
return `\Archive`
case models.MailboxInbox:
return `\Inbox`
}
return ""
}
+401
View File
@@ -0,0 +1,401 @@
package db
import (
"context"
"database/sql"
"fmt"
"math/rand"
"time"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
)
// GetMailbox returns the mailbox with the given name for a user, or nil.
func (d *DB) GetMailbox(ctx context.Context, userID int64, name string) (*models.Mailbox, error) {
row := d.db.QueryRowContext(ctx, `
SELECT id, user_id, name, type, parent_id, uid_validity, uid_next, subscribed, created_at
FROM mailboxes WHERE user_id=? AND name=?`, userID, name)
return scanMailbox(row)
}
// GetMailboxByType returns the first mailbox of the given type for a user.
func (d *DB) GetMailboxByType(ctx context.Context, userID int64, mboxType string) (*models.Mailbox, error) {
row := d.db.QueryRowContext(ctx, `
SELECT id, user_id, name, type, parent_id, uid_validity, uid_next, subscribed, created_at
FROM mailboxes WHERE user_id=? AND type=? LIMIT 1`, userID, mboxType)
return scanMailbox(row)
}
// GetMailboxByID returns the mailbox by ID.
func (d *DB) GetMailboxByID(ctx context.Context, id int64) (*models.Mailbox, error) {
row := d.db.QueryRowContext(ctx, `
SELECT id, user_id, name, type, parent_id, uid_validity, uid_next, subscribed, created_at
FROM mailboxes WHERE id=?`, id)
return scanMailbox(row)
}
// ListMailboxes returns all subscribed mailboxes for a user, ordered by name.
func (d *DB) ListMailboxes(ctx context.Context, userID int64) ([]*models.Mailbox, error) {
rows, err := d.db.QueryContext(ctx, `
SELECT id, user_id, name, type, parent_id, uid_validity, uid_next, subscribed, created_at
FROM mailboxes WHERE user_id=? ORDER BY name`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var mbs []*models.Mailbox
for rows.Next() {
mb, err := scanMailboxRow(rows)
if err != nil {
return nil, err
}
mbs = append(mbs, mb)
}
return mbs, rows.Err()
}
// CreateMailbox creates a mailbox. Returns the new mailbox with uid_validity set.
func (d *DB) CreateMailbox(ctx context.Context, userID int64, name, mboxType string, parentID *int64) (*models.Mailbox, error) {
uidValidity := uint32(rand.Int31()) //nolint:gosec — not a security value
if uidValidity == 0 {
uidValidity = 1
}
res, err := d.db.ExecContext(ctx, `
INSERT INTO mailboxes (user_id, name, type, parent_id, uid_validity, uid_next, subscribed, created_at)
VALUES (?, ?, ?, ?, ?, 1, 1, ?)`,
userID, name, mboxType, parentID, uidValidity, time.Now().UTC())
if err != nil {
return nil, fmt.Errorf("create mailbox: %w", err)
}
id, _ := res.LastInsertId()
return &models.Mailbox{
ID: id,
UserID: userID,
Name: name,
Type: mboxType,
ParentID: parentID,
UIDValidity: uidValidity,
UIDNext: 1,
Subscribed: true,
CreatedAt: time.Now().UTC(),
}, nil
}
// CreateDefaultMailboxes creates the standard mailbox set for a new user.
// Idempotent — skips any that already exist.
func (d *DB) CreateDefaultMailboxes(ctx context.Context, userID int64) error {
defaults := []struct {
name string
mboxType string
}{
{"INBOX", models.MailboxInbox},
{"Sent", models.MailboxSent},
{"Drafts", models.MailboxDrafts},
{"Trash", models.MailboxTrash},
{"Spam", models.MailboxSpam},
{"Archive", models.MailboxArchive},
}
for _, mb := range defaults {
existing, err := d.GetMailbox(ctx, userID, mb.name)
if err != nil {
return err
}
if existing != nil {
continue
}
if _, err := d.CreateMailbox(ctx, userID, mb.name, mb.mboxType, nil); err != nil {
return fmt.Errorf("create default mailbox %s: %w", mb.name, err)
}
}
return nil
}
// NextUID allocates the next UID for a mailbox atomically.
// Returns the UID to use for the new message.
func (d *DB) NextUID(ctx context.Context, mailboxID int64) (uint32, error) {
tx, err := d.db.BeginTx(ctx, nil)
if err != nil {
return 0, err
}
defer tx.Rollback() //nolint:errcheck
var next uint32
err = tx.QueryRowContext(ctx,
"SELECT uid_next FROM mailboxes WHERE id=?", mailboxID).Scan(&next)
if err != nil {
return 0, fmt.Errorf("read uid_next: %w", err)
}
_, err = tx.ExecContext(ctx,
"UPDATE mailboxes SET uid_next=uid_next+1 WHERE id=?", mailboxID)
if err != nil {
return 0, fmt.Errorf("increment uid_next: %w", err)
}
return next, tx.Commit()
}
// ---- Message operations ----
// InsertMessage stores a message record (body is already encrypted; call SaveRawBody separately).
func (d *DB) InsertMessage(ctx context.Context, m *MessageInsert) (int64, error) {
res, err := d.db.ExecContext(ctx, `
INSERT INTO messages
(mailbox_id, uid, message_id, subject, from_email, from_name,
to_list, cc_list, bcc_list, reply_to, date,
body_text_enc, body_html_enc, raw_enc,
size_bytes, has_attachment, is_read, is_starred, is_draft,
flags, spam_score, received_at)
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)`,
m.MailboxID, m.UID, m.MessageID, m.Subject, m.FromEmail, m.FromName,
m.ToList, m.CCList, m.BCCList, m.ReplyTo, m.Date,
m.BodyTextEnc, m.BodyHTMLEnc, m.RawEnc,
m.SizeBytes, m.HasAttachment, m.IsRead, m.IsStarred, m.IsDraft,
m.Flags, m.SpamScore, time.Now().UTC(),
)
if err != nil {
return 0, fmt.Errorf("insert message: %w", err)
}
return res.LastInsertId()
}
// MessageInsert is the data transfer object for inserting a new message.
type MessageInsert struct {
MailboxID int64
UID uint32
MessageID string
Subject string
FromEmail string
FromName string
ToList string
CCList string
BCCList string
ReplyTo string
Date time.Time
BodyTextEnc []byte
BodyHTMLEnc []byte
RawEnc []byte
SizeBytes int64
HasAttachment bool
IsRead bool
IsStarred bool
IsDraft bool
Flags string
SpamScore int
}
// InsertAttachment stores an attachment record for a message.
func (d *DB) InsertAttachment(ctx context.Context, a *AttachmentInsert) (int64, error) {
res, err := d.db.ExecContext(ctx, `
INSERT INTO attachments
(message_id, filename, content_type, size_bytes, data_enc, data_path,
content_id, inline, mime_path)
VALUES (?,?,?,?,?,?,?,?,?)`,
a.MessageID, a.Filename, a.ContentType, a.SizeBytes,
a.DataEnc, a.DataPath, a.ContentID, a.Inline, a.MIMEPath,
)
if err != nil {
return 0, fmt.Errorf("insert attachment: %w", err)
}
return res.LastInsertId()
}
// AttachmentInsert is the data transfer object for inserting an attachment.
type AttachmentInsert struct {
MessageID int64
Filename string
ContentType string
SizeBytes int64
DataEnc []byte
DataPath string
ContentID string
Inline bool
MIMEPath string
}
// GetMessageRaw returns the encrypted raw blob for a message.
func (d *DB) GetMessageRaw(ctx context.Context, messageID int64) ([]byte, error) {
var raw []byte
err := d.db.QueryRowContext(ctx,
"SELECT raw_enc FROM messages WHERE id=?", messageID).Scan(&raw)
if err == sql.ErrNoRows {
return nil, nil
}
return raw, err
}
// ListMessages returns messages in a mailbox ordered by UID descending.
// Only non-deleted messages are returned.
func (d *DB) ListMessages(ctx context.Context, mailboxID int64, limit, offset int) ([]*models.Message, error) {
rows, err := d.db.QueryContext(ctx, `
SELECT id, mailbox_id, uid, message_id, subject, from_email, from_name,
to_list, cc_list, bcc_list, reply_to, date,
size_bytes, has_attachment, is_read, is_starred, is_draft,
flags, spam_score, received_at
FROM messages
WHERE mailbox_id=? AND deleted_at IS NULL
ORDER BY uid DESC
LIMIT ? OFFSET ?`, mailboxID, limit, offset)
if err != nil {
return nil, err
}
defer rows.Close()
var msgs []*models.Message
for rows.Next() {
var m models.Message
err := rows.Scan(
&m.ID, &m.MailboxID, &m.UID, &m.MessageID, &m.Subject,
&m.FromEmail, &m.FromName, &m.ToList, &m.CCList, &m.BCCList,
&m.ReplyTo, &m.Date, &m.SizeBytes, &m.HasAttachment,
&m.IsRead, &m.IsStarred, &m.IsDraft, &m.Flags,
&m.SpamScore, &m.ReceivedAt,
)
if err != nil {
return nil, err
}
msgs = append(msgs, &m)
}
return msgs, rows.Err()
}
// CountUnread returns the number of unread messages in a mailbox.
func (d *DB) CountUnread(ctx context.Context, mailboxID int64) (int, error) {
var n int
err := d.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM messages WHERE mailbox_id=? AND is_read=0 AND deleted_at IS NULL",
mailboxID).Scan(&n)
return n, err
}
// ---- Queue operations ----
// EnqueueMessage inserts a delivery queue entry. Returns the new queue ID.
func (d *DB) EnqueueMessage(ctx context.Context, domainID int64, from, to, msgID string, rawEnc []byte, maxAgeHours int) (int64, error) {
expires := time.Now().UTC().Add(time.Duration(maxAgeHours) * time.Hour)
res, err := d.db.ExecContext(ctx, `
INSERT INTO queue
(domain_id, from_addr, to_addr, raw_enc, message_id, status,
attempts, next_attempt, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, 'pending', 0, ?, ?, ?)`,
domainID, from, to, rawEnc, msgID,
time.Now().UTC(), time.Now().UTC(), expires)
if err != nil {
return 0, fmt.Errorf("enqueue: %w", err)
}
return res.LastInsertId()
}
// PeekQueue returns up to limit pending/retry-eligible queue entries.
func (d *DB) PeekQueue(ctx context.Context, limit int) ([]QueueRow, error) {
rows, err := d.db.QueryContext(ctx, `
SELECT id, domain_id, from_addr, to_addr, raw_enc, message_id,
status, attempts, expires_at
FROM queue
WHERE status IN ('pending','failed')
AND next_attempt <= ?
AND expires_at > ?
ORDER BY next_attempt ASC
LIMIT ?`,
time.Now().UTC(), time.Now().UTC(), limit)
if err != nil {
return nil, err
}
defer rows.Close()
var out []QueueRow
for rows.Next() {
var q QueueRow
var domainID sql.NullInt64
err := rows.Scan(
&q.ID, &domainID, &q.FromAddr, &q.ToAddr,
&q.RawEnc, &q.MessageID, &q.Status, &q.Attempts, &q.ExpiresAt,
)
if err != nil {
return nil, err
}
if domainID.Valid {
q.DomainID = domainID.Int64
}
out = append(out, q)
}
return out, rows.Err()
}
// QueueRow is a minimal queue entry for the delivery worker.
type QueueRow struct {
ID int64
DomainID int64
FromAddr string
ToAddr string
RawEnc []byte
MessageID string
Status string
Attempts int
ExpiresAt time.Time
}
// SetQueueStatus updates the status of a queue entry.
func (d *DB) SetQueueStatus(ctx context.Context, id int64, status, errMsg string, nextAttempt *time.Time) error {
_, err := d.db.ExecContext(ctx, `
UPDATE queue
SET status=?, attempts=attempts+1, last_attempt=?,
error_log=error_log || ?, next_attempt=COALESCE(?, next_attempt)
WHERE id=?`,
status, time.Now().UTC(),
fmt.Sprintf("[%s] %s\n", time.Now().UTC().Format(time.RFC3339), errMsg),
nextAttempt, id)
return err
}
// LogDelivery inserts a delivery log entry.
func (d *DB) LogDelivery(ctx context.Context, queueID int64, from, to, status string, smtpCode int, smtpMsg, mxHost string) error {
_, err := d.db.ExecContext(ctx, `
INSERT INTO delivery_log (queue_id, from_addr, to_addr, status, smtp_code, smtp_message, mx_host, created_at)
VALUES (?,?,?,?,?,?,?,?)`,
queueID, from, to, status, smtpCode, smtpMsg, mxHost, time.Now().UTC())
return err
}
// ---- private ----
func scanMailbox(row *sql.Row) (*models.Mailbox, error) {
var mb models.Mailbox
var parentID sql.NullInt64
err := row.Scan(
&mb.ID, &mb.UserID, &mb.Name, &mb.Type,
&parentID, &mb.UIDValidity, &mb.UIDNext, &mb.Subscribed, &mb.CreatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("scan mailbox: %w", err)
}
if parentID.Valid {
id := parentID.Int64
mb.ParentID = &id
}
return &mb, nil
}
func scanMailboxRow(rows *sql.Rows) (*models.Mailbox, error) {
var mb models.Mailbox
var parentID sql.NullInt64
err := rows.Scan(
&mb.ID, &mb.UserID, &mb.Name, &mb.Type,
&parentID, &mb.UIDValidity, &mb.UIDNext, &mb.Subscribed, &mb.CreatedAt,
)
if err != nil {
return nil, err
}
if parentID.Valid {
id := parentID.Int64
mb.ParentID = &id
}
return &mb, nil
}
+337
View File
@@ -0,0 +1,337 @@
package db
import (
"context"
"database/sql"
"fmt"
"time"
)
// migration is a versioned schema change.
type migration struct {
version int
up string // SQL to apply
}
// migrations must be append-only. Never edit an applied migration.
var migrations = []migration{
{1, schemav1},
}
// migrate applies any unapplied migrations in order.
func (d *DB) migrate() error {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
// Ensure migrations table exists.
_, err := d.db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)`)
if err != nil {
return fmt.Errorf("create migrations table: %w", err)
}
for _, m := range migrations {
var count int
err := d.db.QueryRowContext(ctx,
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?", m.version).Scan(&count)
if err != nil {
return fmt.Errorf("check migration %d: %w", m.version, err)
}
if count > 0 {
continue // already applied
}
if err := d.WithTx(ctx, func(tx *sql.Tx) error {
if _, err := tx.ExecContext(ctx, m.up); err != nil {
return fmt.Errorf("apply migration %d: %w", m.version, err)
}
_, err := tx.ExecContext(ctx,
"INSERT INTO schema_migrations (version) VALUES (?)", m.version)
return err
}); err != nil {
return err
}
fmt.Printf("[db] applied migration %d\n", m.version)
}
return nil
}
// ---- Schema v1 (initial) ----
const schemav1 = `
-- Domains
CREATE TABLE IF NOT EXISTS domains (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
enabled BOOLEAN NOT NULL DEFAULT 1,
dkim_private_enc BLOB,
dkim_public TEXT,
dkim_selector TEXT NOT NULL DEFAULT 'mail',
dkim_algo TEXT NOT NULL DEFAULT 'rsa2048',
spf_policy TEXT,
dmarc_policy TEXT,
max_users INTEGER NOT NULL DEFAULT 0,
max_quota_bytes INTEGER NOT NULL DEFAULT 0,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- Users
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
domain_id INTEGER NOT NULL REFERENCES domains(id) ON DELETE CASCADE,
username TEXT NOT NULL,
email TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
display_name TEXT NOT NULL DEFAULT '',
quota_bytes INTEGER NOT NULL DEFAULT 1073741824,
used_bytes INTEGER NOT NULL DEFAULT 0,
enabled BOOLEAN NOT NULL DEFAULT 1,
admin BOOLEAN NOT NULL DEFAULT 0,
domain_admin BOOLEAN NOT NULL DEFAULT 0,
mfa_secret_enc BLOB,
mfa_enabled BOOLEAN NOT NULL DEFAULT 0,
recovery_codes_enc BLOB,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_login TIMESTAMP
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users(email);
CREATE INDEX IF NOT EXISTS idx_users_domain ON users(domain_id);
-- User aliases
CREATE TABLE IF NOT EXISTS user_aliases (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
alias_email TEXT NOT NULL UNIQUE
);
-- Mailboxes (IMAP folders)
CREATE TABLE IF NOT EXISTS mailboxes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
name TEXT NOT NULL,
type TEXT NOT NULL DEFAULT 'custom',
parent_id INTEGER REFERENCES mailboxes(id) ON DELETE CASCADE,
uid_validity INTEGER NOT NULL DEFAULT 1,
uid_next INTEGER NOT NULL DEFAULT 1,
subscribed BOOLEAN NOT NULL DEFAULT 1,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(user_id, name)
);
CREATE INDEX IF NOT EXISTS idx_mailboxes_user ON mailboxes(user_id);
-- Messages
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mailbox_id INTEGER NOT NULL REFERENCES mailboxes(id) ON DELETE CASCADE,
uid INTEGER NOT NULL,
message_id TEXT,
subject TEXT NOT NULL DEFAULT '',
from_email TEXT NOT NULL DEFAULT '',
from_name TEXT NOT NULL DEFAULT '',
to_list TEXT NOT NULL DEFAULT '',
cc_list TEXT NOT NULL DEFAULT '',
bcc_list TEXT NOT NULL DEFAULT '',
reply_to TEXT NOT NULL DEFAULT '',
date TIMESTAMP,
body_text_enc BLOB,
body_html_enc BLOB,
raw_enc BLOB,
size_bytes INTEGER NOT NULL DEFAULT 0,
has_attachment BOOLEAN NOT NULL DEFAULT 0,
is_read BOOLEAN NOT NULL DEFAULT 0,
is_starred BOOLEAN NOT NULL DEFAULT 0,
is_draft BOOLEAN NOT NULL DEFAULT 0,
flags TEXT NOT NULL DEFAULT '',
spam_score INTEGER NOT NULL DEFAULT 0,
received_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP,
UNIQUE(mailbox_id, uid)
);
CREATE INDEX IF NOT EXISTS idx_messages_mailbox ON messages(mailbox_id);
CREATE INDEX IF NOT EXISTS idx_messages_uid ON messages(mailbox_id, uid);
CREATE INDEX IF NOT EXISTS idx_messages_date ON messages(mailbox_id, date);
CREATE INDEX IF NOT EXISTS idx_messages_deleted ON messages(mailbox_id, deleted_at);
-- Attachments
CREATE TABLE IF NOT EXISTS attachments (
id INTEGER PRIMARY KEY AUTOINCREMENT,
message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
filename TEXT NOT NULL,
content_type TEXT NOT NULL,
size_bytes INTEGER NOT NULL DEFAULT 0,
data_enc BLOB,
data_path TEXT,
content_id TEXT,
inline BOOLEAN NOT NULL DEFAULT 0,
mime_path TEXT
);
CREATE INDEX IF NOT EXISTS idx_attachments_message ON attachments(message_id);
-- Delivery queue
CREATE TABLE IF NOT EXISTS queue (
id INTEGER PRIMARY KEY AUTOINCREMENT,
domain_id INTEGER REFERENCES domains(id),
from_addr TEXT NOT NULL,
to_addr TEXT NOT NULL,
raw_enc BLOB NOT NULL,
message_id TEXT,
status TEXT NOT NULL DEFAULT 'pending',
attempts INTEGER NOT NULL DEFAULT 0,
last_attempt TIMESTAMP,
next_attempt TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
error_log TEXT NOT NULL DEFAULT '',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_queue_status ON queue(status, next_attempt);
-- Delivery log
CREATE TABLE IF NOT EXISTS delivery_log (
id INTEGER PRIMARY KEY AUTOINCREMENT,
queue_id INTEGER REFERENCES queue(id) ON DELETE SET NULL,
from_addr TEXT NOT NULL,
to_addr TEXT NOT NULL,
status TEXT NOT NULL,
smtp_code INTEGER NOT NULL DEFAULT 0,
smtp_message TEXT NOT NULL DEFAULT '',
mx_host TEXT NOT NULL DEFAULT '',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_delivery_log_created ON delivery_log(created_at);
-- Sessions
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL UNIQUE,
ip TEXT NOT NULL DEFAULT '',
user_agent TEXT NOT NULL DEFAULT '',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_sessions_token ON sessions(token_hash);
CREATE INDEX IF NOT EXISTS idx_sessions_user ON sessions(user_id);
CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at);
-- IP bans
CREATE TABLE IF NOT EXISTS ip_bans (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip TEXT NOT NULL UNIQUE,
reason TEXT NOT NULL DEFAULT '',
banned_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP,
released_by TEXT NOT NULL DEFAULT ''
);
CREATE INDEX IF NOT EXISTS idx_ip_bans_ip ON ip_bans(ip);
CREATE INDEX IF NOT EXISTS idx_ip_bans_expires ON ip_bans(expires_at);
-- Login attempts
CREATE TABLE IF NOT EXISTS login_attempts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip TEXT NOT NULL,
user_email TEXT NOT NULL DEFAULT '',
success BOOLEAN NOT NULL DEFAULT 0,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_login_attempts_ip ON login_attempts(ip, created_at);
-- Security events
CREATE TABLE IF NOT EXISTS security_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
type TEXT NOT NULL,
ip TEXT NOT NULL DEFAULT '',
user_id INTEGER REFERENCES users(id) ON DELETE SET NULL,
detail TEXT NOT NULL DEFAULT '',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_security_events_created ON security_events(created_at);
-- External accounts (Gmail / Outlook / custom IMAP)
CREATE TABLE IF NOT EXISTS external_accounts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
provider TEXT NOT NULL,
email_address TEXT NOT NULL,
display_name TEXT NOT NULL DEFAULT '',
access_token_enc BLOB,
refresh_token_enc BLOB,
token_expiry TIMESTAMP,
imap_host TEXT NOT NULL DEFAULT '',
imap_port INTEGER NOT NULL DEFAULT 993,
smtp_host TEXT NOT NULL DEFAULT '',
smtp_port INTEGER NOT NULL DEFAULT 587,
enabled BOOLEAN NOT NULL DEFAULT 1,
sync_enabled BOOLEAN NOT NULL DEFAULT 1,
last_sync TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_ext_accounts_user ON external_accounts(user_id);
-- Address books (CardDAV)
CREATE TABLE IF NOT EXISTS address_books (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
name TEXT NOT NULL,
description TEXT NOT NULL DEFAULT '',
color TEXT NOT NULL DEFAULT '#4A90E2',
sync_token INTEGER NOT NULL DEFAULT 1,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- Contacts (CardDAV)
CREATE TABLE IF NOT EXISTS contacts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
address_book_id INTEGER NOT NULL REFERENCES address_books(id) ON DELETE CASCADE,
uid TEXT NOT NULL,
vcard_enc BLOB NOT NULL,
etag TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(address_book_id, uid)
);
CREATE INDEX IF NOT EXISTS idx_contacts_book ON contacts(address_book_id);
-- Calendars (CalDAV)
CREATE TABLE IF NOT EXISTS calendars (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
name TEXT NOT NULL,
description TEXT NOT NULL DEFAULT '',
color TEXT NOT NULL DEFAULT '#4CAF50',
timezone TEXT NOT NULL DEFAULT 'UTC',
sync_token INTEGER NOT NULL DEFAULT 1,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
-- Calendar events (CalDAV)
CREATE TABLE IF NOT EXISTS calendar_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
calendar_id INTEGER NOT NULL REFERENCES calendars(id) ON DELETE CASCADE,
uid TEXT NOT NULL,
ical_enc BLOB NOT NULL,
etag TEXT NOT NULL,
dt_start TIMESTAMP,
dt_end TIMESTAMP,
summary TEXT NOT NULL DEFAULT '',
recurring BOOLEAN NOT NULL DEFAULT 0,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE(calendar_id, uid)
);
CREATE INDEX IF NOT EXISTS idx_events_calendar ON calendar_events(calendar_id);
CREATE INDEX IF NOT EXISTS idx_events_dtstart ON calendar_events(calendar_id, dt_start);
-- Spam Bayesian tokens (per-user)
CREATE TABLE IF NOT EXISTS spam_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token TEXT NOT NULL,
spam_count INTEGER NOT NULL DEFAULT 0,
ham_count INTEGER NOT NULL DEFAULT 0,
UNIQUE(user_id, token)
);
CREATE INDEX IF NOT EXISTS idx_spam_tokens_user ON spam_tokens(user_id, token);
`
+158
View File
@@ -0,0 +1,158 @@
package db
import (
"context"
"database/sql"
"fmt"
"time"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
)
// GetUserByEmail returns the user with the given email (case-insensitive), or nil.
func (d *DB) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
row := d.db.QueryRowContext(ctx, `
SELECT id, domain_id, username, email, password_hash, display_name,
quota_bytes, used_bytes, enabled, admin, domain_admin,
mfa_secret_enc, mfa_enabled, recovery_codes_enc, created_at, last_login
FROM users WHERE lower(email)=lower(?)`, email)
return scanUser(row)
}
// GetUserByID returns the user with the given ID, or nil.
func (d *DB) GetUserByID(ctx context.Context, id int64) (*models.User, error) {
row := d.db.QueryRowContext(ctx, `
SELECT id, domain_id, username, email, password_hash, display_name,
quota_bytes, used_bytes, enabled, admin, domain_admin,
mfa_secret_enc, mfa_enabled, recovery_codes_enc, created_at, last_login
FROM users WHERE id=?`, id)
return scanUser(row)
}
// UserExistsByEmail returns true if any user (enabled or not) has this email or alias.
func (d *DB) UserExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int
err := d.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM (
SELECT 1 FROM users WHERE lower(email)=lower(?) AND enabled=1
UNION ALL
SELECT 1 FROM user_aliases WHERE lower(alias_email)=lower(?)
)`, email, email).Scan(&count)
return count > 0, err
}
// ResolveEmail returns the canonical user for an email or alias, or nil.
func (d *DB) ResolveEmail(ctx context.Context, email string) (*models.User, error) {
// Direct match first.
u, err := d.GetUserByEmail(ctx, email)
if err != nil || u != nil {
return u, err
}
// Alias match.
var userID int64
err = d.db.QueryRowContext(ctx,
"SELECT user_id FROM user_aliases WHERE lower(alias_email)=lower(?)", email).
Scan(&userID)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return d.GetUserByID(ctx, userID)
}
// CreateUser inserts a new user. Returns the new ID.
func (d *DB) CreateUser(ctx context.Context, domainID int64, username, email, passwordHash, displayName string, quotaBytes int64, domainAdmin bool) (int64, error) {
res, err := d.db.ExecContext(ctx, `
INSERT INTO users
(domain_id, username, email, password_hash, display_name, quota_bytes,
enabled, admin, domain_admin, created_at)
VALUES (?, ?, ?, ?, ?, ?, 1, 0, ?, ?)`,
domainID, username, email, passwordHash, displayName, quotaBytes, domainAdmin,
time.Now().UTC())
if err != nil {
return 0, fmt.Errorf("create user: %w", err)
}
return res.LastInsertId()
}
// UpdateUsedBytes sets the cached used_bytes for a user (approximate, updated on store).
func (d *DB) UpdateUsedBytes(ctx context.Context, userID int64, delta int64) error {
_, err := d.db.ExecContext(ctx,
"UPDATE users SET used_bytes = MAX(0, used_bytes + ?) WHERE id=?",
delta, userID)
return err
}
// UpdateLastLogin sets last_login to now.
func (d *DB) UpdateLastLogin(ctx context.Context, userID int64) {
d.db.ExecContext(ctx, //nolint:errcheck — best-effort
"UPDATE users SET last_login=? WHERE id=?", time.Now().UTC(), userID)
}
// ListUsers returns all users for a domain.
func (d *DB) ListUsers(ctx context.Context, domainID int64) ([]*models.User, error) {
rows, err := d.db.QueryContext(ctx, `
SELECT id, domain_id, username, email, password_hash, display_name,
quota_bytes, used_bytes, enabled, admin, domain_admin,
mfa_secret_enc, mfa_enabled, recovery_codes_enc, created_at, last_login
FROM users WHERE domain_id=? ORDER BY email`, domainID)
if err != nil {
return nil, err
}
defer rows.Close()
var users []*models.User
for rows.Next() {
var u models.User
var mfaEnc, rcEnc []byte
var lastLogin sql.NullTime
err := rows.Scan(
&u.ID, &u.DomainID, &u.Username, &u.Email, &u.PasswordHash,
&u.DisplayName, &u.QuotaBytes, &u.UsedBytes, &u.Enabled,
&u.Admin, &u.DomainAdmin,
&mfaEnc, &u.MFAEnabled, &rcEnc,
&u.CreatedAt, &lastLogin,
)
if err != nil {
return nil, err
}
u.MFASecretEnc = mfaEnc
u.RecoveryCodesEnc = rcEnc
if lastLogin.Valid {
u.LastLogin = lastLogin.Time
}
users = append(users, &u)
}
return users, rows.Err()
}
// ---- private ----
func scanUser(row *sql.Row) (*models.User, error) {
var u models.User
var mfaEnc, rcEnc []byte
var lastLogin sql.NullTime
err := row.Scan(
&u.ID, &u.DomainID, &u.Username, &u.Email, &u.PasswordHash,
&u.DisplayName, &u.QuotaBytes, &u.UsedBytes, &u.Enabled,
&u.Admin, &u.DomainAdmin,
&mfaEnc, &u.MFAEnabled, &rcEnc,
&u.CreatedAt, &lastLogin,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("scan user: %w", err)
}
u.MFASecretEnc = mfaEnc
u.RecoveryCodesEnc = rcEnc
if lastLogin.Valid {
u.LastLogin = lastLogin.Time
}
return &u, nil
}
+211
View File
@@ -0,0 +1,211 @@
// Package delivery implements outbound SMTP delivery: MX lookup, connection,
// TLS upgrade, message submission. Used by the queue worker.
package delivery
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"log"
"net"
"net/smtp"
"sort"
"strings"
"time"
)
const (
deliveryTimeout = 60 * time.Second
connectTimeout = 30 * time.Second
)
// Result holds the outcome of a single delivery attempt.
type Result struct {
MXHost string
SMTPCode int
Message string
Perm bool // true = permanent failure (5xx), don't retry
}
// Deliver attempts to deliver raw to the address to using a fresh SMTP
// connection to the recipient domain's MX. Signs with the given EHLO hostname.
// Returns a Result describing success or failure.
func Deliver(ctx context.Context, ehloHostname, from, to string, raw []byte) *Result {
at := strings.LastIndex(to, "@")
if at < 0 {
return &Result{Perm: true, Message: "invalid recipient address: " + to}
}
toDomain := strings.ToLower(to[at+1:])
mxHosts, err := lookupMX(ctx, toDomain)
if err != nil {
return &Result{Message: fmt.Sprintf("MX lookup %s: %v", toDomain, err)}
}
if len(mxHosts) == 0 {
return &Result{Perm: true, Message: "no MX records for " + toDomain}
}
var lastResult *Result
for _, host := range mxHosts {
r := deliver(ctx, ehloHostname, host, from, to, raw)
lastResult = r
if r.SMTPCode == 0 || r.SMTPCode/100 == 4 {
// Temp error or connection failure — try next MX.
continue
}
// 2xx = success, 5xx = permanent failure — stop trying.
return r
}
return lastResult
}
// deliver connects to one MX host and submits the message.
func deliver(ctx context.Context, ehloHostname, mxHost, from, to string, raw []byte) *Result {
addr := net.JoinHostPort(mxHost, "25")
dialCtx, cancel := context.WithTimeout(ctx, connectTimeout)
defer cancel()
conn, err := (&net.Dialer{}).DialContext(dialCtx, "tcp", addr)
if err != nil {
return &Result{MXHost: mxHost, Message: fmt.Sprintf("connect %s: %v", addr, err)}
}
// Wrap in a deadline for the full SMTP exchange.
deadline := time.Now().Add(deliveryTimeout)
_ = conn.SetDeadline(deadline)
c, err := smtp.NewClient(conn, mxHost)
if err != nil {
conn.Close()
return &Result{MXHost: mxHost, Message: fmt.Sprintf("smtp client %s: %v", mxHost, err)}
}
defer c.Close()
// EHLO.
if err := c.Hello(ehloHostname); err != nil {
return &Result{MXHost: mxHost, Message: fmt.Sprintf("EHLO: %v", err)}
}
// Try STARTTLS (best effort — not all remote servers require it).
if ok, _ := c.Extension("STARTTLS"); ok {
tlsCfg := &tls.Config{
ServerName: mxHost,
MinVersion: tls.VersionTLS12,
}
if err := c.StartTLS(tlsCfg); err != nil {
log.Printf("[delivery] STARTTLS %s failed (continuing plain): %v", mxHost, err)
}
}
// MAIL FROM.
if err := c.Mail(from); err != nil {
return smtpResult(mxHost, err)
}
// RCPT TO.
if err := c.Rcpt(to); err != nil {
return smtpResult(mxHost, err)
}
// DATA.
wc, err := c.Data()
if err != nil {
return smtpResult(mxHost, err)
}
if _, err := wc.Write(raw); err != nil {
wc.Close()
return &Result{MXHost: mxHost, Message: fmt.Sprintf("write data: %v", err)}
}
if err := wc.Close(); err != nil {
return smtpResult(mxHost, err)
}
_ = c.Quit()
log.Printf("[delivery] delivered %s → %s via %s", from, to, mxHost)
return &Result{MXHost: mxHost, SMTPCode: 250, Message: "2.0.0 OK"}
}
// lookupMX resolves MX records and returns hosts sorted by priority.
func lookupMX(ctx context.Context, domain string) ([]string, error) {
r := net.DefaultResolver
mxs, err := r.LookupMX(ctx, domain)
if err != nil {
// Treat NXDOMAIN as no-MX (not a transient error).
if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound {
// Fall back: try A record (some small domains don't publish MX).
addrs, aerr := r.LookupHost(ctx, domain)
if aerr != nil || len(addrs) == 0 {
return nil, nil
}
return []string{domain}, nil
}
return nil, err
}
// Sort by priority ascending.
sort.Slice(mxs, func(i, j int) bool {
return mxs[i].Pref < mxs[j].Pref
})
hosts := make([]string, 0, len(mxs))
for _, mx := range mxs {
h := strings.TrimSuffix(mx.Host, ".")
if h != "" {
hosts = append(hosts, h)
}
}
return hosts, nil
}
// smtpResult maps an smtp error to a Result, marking 5xx as permanent.
func smtpResult(mxHost string, err error) *Result {
if err == nil {
return &Result{MXHost: mxHost, SMTPCode: 250, Message: "2.0.0 OK"}
}
msg := err.Error()
code := parseCode(msg)
return &Result{
MXHost: mxHost,
SMTPCode: code,
Message: msg,
Perm: code/100 == 5,
}
}
// parseCode extracts the leading 3-digit SMTP code from an error string.
func parseCode(s string) int {
if len(s) < 3 {
return 0
}
var code int
_, _ = fmt.Sscanf(s[:3], "%d", &code)
return code
}
// IsLocal reports whether the given domain is served locally. Used by the
// queue worker to skip outbound delivery for internal mail.
// The caller supplies its own domain list to avoid a DB call here.
func IsLocal(domain string, localDomains []string) bool {
domain = strings.ToLower(domain)
for _, d := range localDomains {
if strings.EqualFold(d, domain) {
return true
}
}
return false
}
// RecipientDomain returns the domain part of an email address.
func RecipientDomain(addr string) string {
at := strings.LastIndex(addr, "@")
if at < 0 {
return ""
}
return strings.ToLower(addr[at+1:])
}
// _ suppresses unused import if bytes is only used in tests later.
var _ = bytes.NewReader
+465
View File
@@ -0,0 +1,465 @@
// Package dkim implements DKIM signing (outbound) and verification (inbound)
// for RSA-2048 and Ed25519 keys per RFC 6376 and RFC 8463.
package dkim
import (
gocrypto "crypto"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"net"
"strings"
"time"
)
// Signer holds a loaded private key and signing metadata.
type Signer struct {
privateKey gocrypto.PrivateKey // *rsa.PrivateKey or ed25519.PrivateKey
selector string
domain string
algo string // "rsa2048" | "ed25519"
}
// GenerateKeyPair generates a DKIM key pair for the given algorithm.
// algo must be "rsa2048" or "ed25519".
// Returns PEM-encoded private key and PEM-encoded public key.
func GenerateKeyPair(algo string) (privateKeyPEM, publicKeyPEM string, err error) {
switch algo {
case "rsa2048":
priv, genErr := rsa.GenerateKey(rand.Reader, 2048)
if genErr != nil {
return "", "", fmt.Errorf("dkim: generate RSA key: %w", genErr)
}
privDER := x509.MarshalPKCS1PrivateKey(priv)
privBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privDER}
privateKeyPEM = string(pem.EncodeToMemory(privBlock))
pubDER, marshalErr := x509.MarshalPKIXPublicKey(&priv.PublicKey)
if marshalErr != nil {
return "", "", fmt.Errorf("dkim: marshal RSA public key: %w", marshalErr)
}
pubBlock := &pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}
publicKeyPEM = string(pem.EncodeToMemory(pubBlock))
return privateKeyPEM, publicKeyPEM, nil
case "ed25519":
pub, priv, genErr := ed25519.GenerateKey(rand.Reader)
if genErr != nil {
return "", "", fmt.Errorf("dkim: generate Ed25519 key: %w", genErr)
}
privDER, marshalErr := x509.MarshalPKCS8PrivateKey(priv)
if marshalErr != nil {
return "", "", fmt.Errorf("dkim: marshal Ed25519 private key: %w", marshalErr)
}
privBlock := &pem.Block{Type: "PRIVATE KEY", Bytes: privDER}
privateKeyPEM = string(pem.EncodeToMemory(privBlock))
pubDER, marshalErr := x509.MarshalPKIXPublicKey(pub)
if marshalErr != nil {
return "", "", fmt.Errorf("dkim: marshal Ed25519 public key: %w", marshalErr)
}
pubBlock := &pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}
publicKeyPEM = string(pem.EncodeToMemory(pubBlock))
return privateKeyPEM, publicKeyPEM, nil
default:
return "", "", fmt.Errorf("dkim: unsupported algorithm %q (want rsa2048 or ed25519)", algo)
}
}
// NewSigner parses a PEM private key and returns a Signer.
// Tries PKCS1 RSA first, then PKCS8 (Ed25519 or RSA).
func NewSigner(privateKeyPEM, domain, selector string) (*Signer, error) {
if domain == "" {
return nil, fmt.Errorf("dkim: domain must not be empty")
}
if selector == "" {
return nil, fmt.Errorf("dkim: selector must not be empty")
}
block, _ := pem.Decode([]byte(privateKeyPEM))
if block == nil {
return nil, fmt.Errorf("dkim: failed to decode PEM block")
}
// Try PKCS1 RSA.
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return &Signer{
privateKey: rsaKey,
selector: selector,
domain: domain,
algo: "rsa2048",
}, nil
}
// Try PKCS8 (covers Ed25519 and RSA PKCS8).
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("dkim: parse private key: %w", err)
}
switch k := key.(type) {
case ed25519.PrivateKey:
return &Signer{
privateKey: k,
selector: selector,
domain: domain,
algo: "ed25519",
}, nil
case *rsa.PrivateKey:
return &Signer{
privateKey: k,
selector: selector,
domain: domain,
algo: "rsa2048",
}, nil
default:
return nil, fmt.Errorf("dkim: unsupported private key type %T", key)
}
}
// Sign produces a DKIM-Signature header for the given RFC822 message.
// Returns the complete "DKIM-Signature: ..." header line (no trailing CRLF).
func (s *Signer) Sign(message []byte) (string, error) {
// Split at first blank line (\r\n\r\n or \n\n).
headerBytes, bodyBytes := splitMessage(message)
// Canonicalize body (relaxed).
canonBody := canonicalizeBodyRelaxed(bodyBytes)
// Body hash.
bodyHash := sha256.Sum256(canonBody)
bh := base64.StdEncoding.EncodeToString(bodyHash[:])
// Determine algorithm tag.
var aTag string
switch s.algo {
case "ed25519":
aTag = "ed25519-sha256"
default:
aTag = "rsa-sha256"
}
// Signed header fields (lower-case, in sign order).
signedFields := "from:to:subject:date:message-id"
// Build DKIM-Signature with b= empty.
ts := fmt.Sprintf("%d", time.Now().Unix())
sigHeader := fmt.Sprintf(
"DKIM-Signature: v=1; a=%s; c=relaxed/relaxed; d=%s; s=%s; t=%s; bh=%s; h=%s; b=",
aTag, s.domain, s.selector, ts, bh, signedFields,
)
// Canonicalize headers to sign + the sig header (b= empty).
hdrMap := parseHeaders(headerBytes)
var sb strings.Builder
for _, field := range strings.Split(signedFields, ":") {
field = strings.TrimSpace(field)
val, ok := hdrMap[strings.ToLower(field)]
if !ok {
continue
}
sb.WriteString(canonicalizeHeaderRelaxed(field, val))
sb.WriteString("\r\n")
}
// Append the DKIM-Signature line itself (with b= empty), canonicalized.
sb.WriteString(canonicalizeHeaderRelaxed("dkim-signature", strings.TrimPrefix(sigHeader, "DKIM-Signature: ")))
dataToSign := []byte(sb.String())
hash := sha256.Sum256(dataToSign)
var sigBytes []byte
var signErr error
switch s.algo {
case "ed25519":
privKey, ok := s.privateKey.(ed25519.PrivateKey)
if !ok {
return "", fmt.Errorf("dkim: private key type mismatch for ed25519")
}
// RFC 8463: ed25519-sha256 — sign the SHA-256 hash of the data.
sigBytes = ed25519.Sign(privKey, hash[:])
default:
privKey, ok := s.privateKey.(*rsa.PrivateKey)
if !ok {
return "", fmt.Errorf("dkim: private key type mismatch for RSA")
}
sigBytes, signErr = rsa.SignPKCS1v15(rand.Reader, privKey, gocrypto.SHA256, hash[:])
if signErr != nil {
return "", fmt.Errorf("dkim: RSA sign: %w", signErr)
}
}
b := base64.StdEncoding.EncodeToString(sigBytes)
return sigHeader + b, nil
}
// Verify finds the DKIM-Signature in message, fetches the DNS public key,
// and verifies the signature. Returns the signing domain on success.
func Verify(message []byte) (domain string, err error) {
headerBytes, bodyBytes := splitMessage(message)
hdrMap := parseHeaders(headerBytes)
sigVal, ok := hdrMap["dkim-signature"]
if !ok {
return "", fmt.Errorf("dkim: no DKIM-Signature header found")
}
params := parseDKIMParams(sigVal)
sel, ok := params["s"]
if !ok || sel == "" {
return "", fmt.Errorf("dkim: missing selector (s=) in DKIM-Signature")
}
dom, ok := params["d"]
if !ok || dom == "" {
return "", fmt.Errorf("dkim: missing domain (d=) in DKIM-Signature")
}
bh64, _ := params["bh"]
b64, ok := params["b"]
if !ok {
return "", fmt.Errorf("dkim: missing signature (b=) in DKIM-Signature")
}
signedFields, _ := params["h"]
// Verify body hash.
canonBody := canonicalizeBodyRelaxed(bodyBytes)
bodyHash := sha256.Sum256(canonBody)
expectedBH := base64.StdEncoding.EncodeToString(bodyHash[:])
// Strip whitespace from DNS-retrieved bh for comparison.
cleanBH := strings.Map(func(r rune) rune {
if r == ' ' || r == '\t' || r == '\r' || r == '\n' {
return -1
}
return r
}, bh64)
if cleanBH != expectedBH {
return "", fmt.Errorf("dkim: body hash mismatch")
}
// DNS lookup.
lookupName := sel + "._domainkey." + dom
txts, lookupErr := net.LookupTXT(lookupName)
if lookupErr != nil {
return "", fmt.Errorf("dkim: DNS lookup %s: %w", lookupName, lookupErr)
}
if len(txts) == 0 {
return "", fmt.Errorf("dkim: no TXT record at %s", lookupName)
}
// Join TXT record parts (DNS may split at 255 bytes).
dnsRecord := strings.Join(txts, "")
dnsParams := parseDKIMParams(dnsRecord)
pVal, ok := dnsParams["p"]
if !ok || pVal == "" {
return "", fmt.Errorf("dkim: no p= (public key) in DNS TXT record")
}
pubDER, decErr := base64.StdEncoding.DecodeString(pVal)
if decErr != nil {
return "", fmt.Errorf("dkim: decode public key base64: %w", decErr)
}
// Rebuild the data-to-sign exactly as Signer.Sign did.
// Canonicalize the signed headers.
var sb strings.Builder
for _, field := range strings.Split(signedFields, ":") {
field = strings.TrimSpace(field)
val, ok2 := hdrMap[strings.ToLower(field)]
if !ok2 {
continue
}
sb.WriteString(canonicalizeHeaderRelaxed(field, val))
sb.WriteString("\r\n")
}
// Append DKIM-Signature with b= stripped (zeroed), canonicalized.
cleanedSig := stripBValue(sigVal)
sb.WriteString(canonicalizeHeaderRelaxed("dkim-signature", cleanedSig))
dataToVerify := []byte(sb.String())
hash := sha256.Sum256(dataToVerify)
sigBytes, decErr := base64.StdEncoding.DecodeString(
strings.Map(func(r rune) rune {
if r == ' ' || r == '\t' || r == '\r' || r == '\n' {
return -1
}
return r
}, b64),
)
if decErr != nil {
return "", fmt.Errorf("dkim: decode signature base64: %w", decErr)
}
// Try RSA PKIX public key first.
if pubKey, rsaErr := x509.ParsePKIXPublicKey(pubDER); rsaErr == nil {
switch k := pubKey.(type) {
case *rsa.PublicKey:
if err := rsa.VerifyPKCS1v15(k, gocrypto.SHA256, hash[:], sigBytes); err != nil {
return "", fmt.Errorf("dkim: RSA signature invalid: %w", err)
}
return dom, nil
case ed25519.PublicKey:
if !ed25519.Verify(k, hash[:], sigBytes) {
return "", fmt.Errorf("dkim: Ed25519 signature invalid")
}
return dom, nil
default:
return "", fmt.Errorf("dkim: unsupported public key type %T", pubKey)
}
}
// Try raw Ed25519 public key (32 bytes).
if len(pubDER) == ed25519.PublicKeySize {
edPub := ed25519.PublicKey(pubDER)
if !ed25519.Verify(edPub, hash[:], sigBytes) {
return "", fmt.Errorf("dkim: Ed25519 signature invalid")
}
return dom, nil
}
return "", fmt.Errorf("dkim: unable to parse public key from DNS record")
}
// DNSRecord returns the DKIM TXT record string for publishing.
func DNSRecord(selector, domain, publicKeyPEM string) string {
block, _ := pem.Decode([]byte(publicKeyPEM))
var p string
if block != nil {
p = base64.StdEncoding.EncodeToString(block.Bytes)
}
return fmt.Sprintf("%s._domainkey.%s. IN TXT \"v=DKIM1; k=rsa; p=%s\"", selector, domain, p)
}
// SPFRecord returns the recommended SPF TXT record for a domain.
func SPFRecord(domain string) string {
return fmt.Sprintf("%s IN TXT \"v=spf1 mx a ~all\"", domain)
}
// ---- Internals ----
// splitMessage splits at the first blank line (CRLF or LF variants).
func splitMessage(msg []byte) (headers, body []byte) {
s := string(msg)
for _, sep := range []string{"\r\n\r\n", "\n\n"} {
if idx := strings.Index(s, sep); idx >= 0 {
return []byte(s[:idx]), []byte(s[idx+len(sep):])
}
}
return msg, nil
}
// canonicalizeBodyRelaxed applies RFC 6376 relaxed body canonicalization.
func canonicalizeBodyRelaxed(body []byte) []byte {
if len(body) == 0 {
return []byte("\r\n")
}
// Normalize line endings.
s := strings.ReplaceAll(string(body), "\r\n", "\n")
s = strings.ReplaceAll(s, "\r", "\n")
lines := strings.Split(s, "\n")
var out []string
for _, line := range lines {
// Collapse whitespace runs within each line, trim trailing whitespace.
fields := strings.Fields(line)
out = append(out, strings.Join(fields, " "))
}
// Strip trailing empty lines.
for len(out) > 0 && out[len(out)-1] == "" {
out = out[:len(out)-1]
}
result := strings.Join(out, "\r\n") + "\r\n"
return []byte(result)
}
// canonicalizeHeaderRelaxed applies RFC 6376 relaxed header canonicalization
// to a single header name + value. Returns "lowername:value" (no trailing CRLF).
func canonicalizeHeaderRelaxed(name, value string) string {
name = strings.ToLower(strings.TrimSpace(name))
// Collapse all whitespace (including CRLF folding) to a single space.
value = strings.ReplaceAll(value, "\r\n", " ")
value = strings.ReplaceAll(value, "\r", " ")
value = strings.ReplaceAll(value, "\n", " ")
// Collapse multiple spaces.
parts := strings.Fields(value)
value = strings.Join(parts, " ")
value = strings.TrimSpace(value)
return name + ":" + value
}
// parseHeaders builds a map of lower-cased header name → last value.
// Handles multi-line (folded) headers.
func parseHeaders(headerBytes []byte) map[string]string {
m := make(map[string]string)
lines := strings.Split(strings.ReplaceAll(string(headerBytes), "\r\n", "\n"), "\n")
var curName, curVal string
flush := func() {
if curName != "" {
m[strings.ToLower(curName)] = curVal
}
}
for _, line := range lines {
if line == "" {
continue
}
if len(line) > 0 && (line[0] == ' ' || line[0] == '\t') {
// Folded continuation.
curVal += " " + strings.TrimSpace(line)
continue
}
flush()
idx := strings.IndexByte(line, ':')
if idx < 0 {
curName = ""
curVal = ""
continue
}
curName = line[:idx]
curVal = strings.TrimSpace(line[idx+1:])
}
flush()
return m
}
// parseDKIMParams parses semicolon-separated tag=value pairs (DKIM and DNS TXT).
func parseDKIMParams(s string) map[string]string {
m := make(map[string]string)
for _, part := range strings.Split(s, ";") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
idx := strings.IndexByte(part, '=')
if idx < 0 {
continue
}
key := strings.TrimSpace(part[:idx])
val := strings.TrimSpace(part[idx+1:])
m[key] = val
}
return m
}
// stripBValue removes the value of the b= tag (sets it to empty) so the
// header can be re-canonicalized for verification.
func stripBValue(sigVal string) string {
// Find b= and zero everything after it up to the next ;.
parts := strings.Split(sigVal, ";")
for i, p := range parts {
trimmed := strings.TrimSpace(p)
if strings.HasPrefix(trimmed, "b=") {
parts[i] = " b="
}
}
return strings.Join(parts, ";")
}
+125
View File
@@ -0,0 +1,125 @@
// Package dmarc implements basic DMARC (RFC 7489) policy evaluation.
// Only stdlib net is used for DNS. Supports p=none, p=quarantine, p=reject.
package dmarc
import (
"fmt"
"net"
"strings"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/spf"
)
// Policy represents the DMARC disposition policy.
type Policy int
const (
PolicyNone Policy = iota // p=none
PolicyQuarantine // p=quarantine
PolicyReject // p=reject
)
func (p Policy) String() string {
return [...]string{"none", "quarantine", "reject"}[p]
}
// Result holds the DMARC evaluation outcome.
type Result struct {
Pass bool
Policy Policy
Disposition string // none | quarantine | reject
Reason string
}
// Check evaluates DMARC policy for the given envelope-from domain.
// spfPass: whether SPF passed for this envelope-from domain.
// dkimDomains: set of signing domains that produced a valid DKIM signature.
// Returns a Result and logs no external calls beyond DNS.
func Check(fromDomain string, spfResult spf.Result, dkimDomains []string) *Result {
record, err := fetchDMARC(fromDomain)
if err != nil || record == "" {
// No DMARC record — not a failure, but can't enforce.
return &Result{
Pass: true,
Policy: PolicyNone,
Disposition: "none",
Reason: "no DMARC record for " + fromDomain,
}
}
policy := parsePolicy(record)
// SPF alignment (relaxed): envelope-from domain must equal or be a subdomain
// of the DMARC organisational domain.
orgDomain := orgDomainOf(fromDomain)
spfAligned := spfResult == spf.ResultPass
// DKIM alignment (relaxed): at least one valid DKIM signature domain must
// equal or be a subdomain of the org domain.
dkimAligned := false
for _, d := range dkimDomains {
if strings.EqualFold(d, fromDomain) || strings.HasSuffix(strings.ToLower(d), "."+strings.ToLower(orgDomain)) {
dkimAligned = true
break
}
}
pass := spfAligned || dkimAligned
reason := fmt.Sprintf("SPF=%v DKIM-aligned=%v policy=%s", spfAligned, dkimAligned, policy)
if pass {
return &Result{
Pass: true,
Policy: policy,
Disposition: "none",
Reason: reason,
}
}
return &Result{
Pass: false,
Policy: policy,
Disposition: policy.String(),
Reason: reason,
}
}
// fetchDMARC queries TXT at _dmarc.<domain> and returns the first DMARC record.
func fetchDMARC(domain string) (string, error) {
txts, err := net.LookupTXT("_dmarc." + domain)
if err != nil {
if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound {
return "", nil
}
return "", err
}
for _, txt := range txts {
txt = strings.TrimSpace(txt)
if strings.HasPrefix(txt, "v=DMARC1") {
return txt, nil
}
}
return "", nil
}
// parsePolicy extracts the p= tag from a DMARC record.
func parsePolicy(record string) Policy {
for _, part := range strings.Split(record, ";") {
part = strings.TrimSpace(part)
if strings.HasPrefix(strings.ToLower(part), "p=") {
switch strings.ToLower(part[2:]) {
case "quarantine":
return PolicyQuarantine
case "reject":
return PolicyReject
}
}
}
return PolicyNone
}
// orgDomainOf returns a simplified "organisational domain" — for this
// implementation we use the domain as-is (a full PSL lookup is out of scope).
func orgDomainOf(domain string) string {
return strings.ToLower(domain)
}
+80
View File
@@ -0,0 +1,80 @@
// Package imap provides an IMAP4rev2 server backed by the encrypted SQLite store.
package imap
import (
"context"
"crypto/tls"
"fmt"
"log"
"net"
"time"
"github.com/emersion/go-imap/v2"
"github.com/emersion/go-imap/v2/imapserver"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/auth"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/config"
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
)
// Deps groups dependencies needed by the IMAP server.
type Deps struct {
DB *db.DB
Crypt *appCrypto.Crypto
Brute *auth.BruteGuard
Cfg *config.Config
}
// NewServer creates an IMAP4rev2 server for port 143 (STARTTLS).
func NewServer(d *Deps, tlsCfg *tls.Config) *imapserver.Server {
return imapserver.New(&imapserver.Options{
NewSession: func(c *imapserver.Conn) (imapserver.Session, *imapserver.GreetingData, error) {
clientIP := connRemoteIP(c)
if d.Brute != nil && d.Cfg.BruteMaxTries > 0 {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
if banned, err := d.Brute.IsBanned(ctx, clientIP); err == nil && banned {
cancel()
return nil, nil, fmt.Errorf("connection refused")
}
cancel()
}
return &IMAPSession{deps: d, clientIP: clientIP}, &imapserver.GreetingData{}, nil
},
Caps: imap.CapSet{
imap.CapIMAP4rev2: {},
imap.CapIMAP4rev1: {},
},
TLSConfig: tlsCfg,
InsecureAuth: tlsCfg == nil,
})
}
// ListenAndServe listens on addr (port 143 / STARTTLS).
func ListenAndServe(s *imapserver.Server, addr, name string) error {
ln, err := net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("%s listen %s: %w", name, addr, err)
}
log.Printf("[%s] listening on %s", name, addr)
return s.Serve(ln)
}
// ListenAndServeTLS listens with implicit TLS (port 993 / IMAPS).
func ListenAndServeTLS(s *imapserver.Server, addr, name string) error {
log.Printf("[%s] listening on %s (TLS)", name, addr)
return s.ListenAndServeTLS(addr)
}
// connRemoteIP returns the client IP from an imapserver.Conn.
func connRemoteIP(c *imapserver.Conn) string {
if c == nil {
return ""
}
nc := c.NetConn()
if nc == nil {
return ""
}
host, _, _ := net.SplitHostPort(nc.RemoteAddr().String())
return host
}
+884
View File
@@ -0,0 +1,884 @@
package imap
import (
"bufio"
"bytes"
"context"
"fmt"
"log"
"sort"
"strings"
"time"
"github.com/emersion/go-imap/v2"
"github.com/emersion/go-imap/v2/imapserver"
"github.com/emersion/go-message/textproto"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
)
const mailboxDelim rune = '/'
// msgEntry holds the in-memory descriptor for one message in the selected mailbox.
type msgEntry struct {
dbID int64
uid imap.UID
isRead bool
isStarred bool
isDraft bool
isDeleted bool
extraFlags string
size int64
internalDate time.Time
}
func (e *msgEntry) flagList() []imap.Flag {
var flags []imap.Flag
if e.isRead {
flags = append(flags, imap.FlagSeen)
}
if e.isStarred {
flags = append(flags, imap.FlagFlagged)
}
if e.isDraft {
flags = append(flags, imap.FlagDraft)
}
if e.isDeleted {
flags = append(flags, imap.FlagDeleted)
}
for _, f := range strings.Fields(e.extraFlags) {
flags = append(flags, imap.Flag(f))
}
return flags
}
// IMAPSession implements imapserver.SessionIMAP4rev2.
type IMAPSession struct {
deps *Deps
clientIP string
user *models.User // set after Login
selectedMailbox *models.Mailbox
msgs []msgEntry // index+1 = IMAP sequence number
mboxTracker *imapserver.MailboxTracker
sessionTracker *imapserver.SessionTracker
}
var _ imapserver.SessionIMAP4rev2 = (*IMAPSession)(nil)
// ---- Authentication ----
func (s *IMAPSession) Close() error {
if s.sessionTracker != nil {
s.sessionTracker.Close()
}
return nil
}
func (s *IMAPSession) Login(username, password string) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
user, err := s.deps.DB.GetUserByEmail(ctx, username)
if err != nil {
log.Printf("[imap] login error %s: %v", username, err)
return imapserver.ErrAuthFailed
}
if user == nil || !user.Enabled {
s.recordAttempt(ctx, username, false)
return imapserver.ErrAuthFailed
}
if err := crypto.CheckPassword(user.PasswordHash, password); err != nil {
log.Printf("[imap] auth failed %s from %s", username, s.clientIP)
s.recordAttempt(ctx, username, false)
return imapserver.ErrAuthFailed
}
s.user = user
s.deps.DB.UpdateLastLogin(ctx, user.ID) //nolint:errcheck
s.recordAttempt(ctx, username, true)
log.Printf("[imap] auth OK %s from %s", username, s.clientIP)
return nil
}
func (s *IMAPSession) recordAttempt(ctx context.Context, email string, success bool) {
if s.deps.Brute != nil {
s.deps.Brute.RecordAttempt(ctx, s.clientIP, email, success) //nolint:errcheck
}
}
// ---- Mailbox management ----
func (s *IMAPSession) Select(mailboxName string, options *imap.SelectOptions) (*imap.SelectData, error) {
s.doUnselect()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
mbox, err := s.deps.DB.GetMailbox(ctx, s.user.ID, mailboxName)
if err != nil {
return nil, imapErr(err)
}
if mbox == nil {
return nil, noSuchMailbox()
}
msgs, err := s.deps.DB.ListIMAPMessages(ctx, mbox.ID)
if err != nil {
return nil, imapErr(err)
}
s.selectedMailbox = mbox
s.msgs = make([]msgEntry, 0, len(msgs))
for _, m := range msgs {
s.msgs = append(s.msgs, msgEntry{
dbID: m.ID,
uid: imap.UID(m.UID),
isRead: m.IsRead,
isStarred: m.IsStarred,
isDraft: m.IsDraft,
extraFlags: m.Flags,
size: m.SizeBytes,
internalDate: m.ReceivedAt,
})
}
s.mboxTracker = imapserver.NewMailboxTracker(uint32(len(s.msgs)))
s.sessionTracker = s.mboxTracker.NewSession()
var firstUnseen uint32
for i, m := range s.msgs {
if !m.isRead {
firstUnseen = uint32(i + 1)
break
}
}
return &imap.SelectData{
Flags: allFlags(),
PermanentFlags: allFlags(),
NumMessages: uint32(len(s.msgs)),
FirstUnseenSeqNum: firstUnseen,
UIDNext: imap.UID(mbox.UIDNext),
UIDValidity: mbox.UIDValidity,
List: &imap.ListData{
Mailbox: mailboxName,
Delim: mailboxDelim,
Attrs: mboxAttrs(mbox),
},
}, nil
}
func (s *IMAPSession) Unselect() error {
s.doUnselect()
return nil
}
func (s *IMAPSession) doUnselect() {
if s.sessionTracker != nil {
s.sessionTracker.Close()
s.sessionTracker = nil
}
s.selectedMailbox = nil
s.msgs = nil
s.mboxTracker = nil
}
func (s *IMAPSession) Create(mailboxName string, options *imap.CreateOptions) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
existing, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, mailboxName)
if existing != nil {
return &imap.Error{Type: imap.StatusResponseTypeNo, Code: imap.ResponseCodeAlreadyExists, Text: "Mailbox already exists"}
}
_, err := s.deps.DB.CreateMailbox(ctx, s.user.ID, mailboxName, "", nil)
return imapErr(err)
}
func (s *IMAPSession) Delete(mailboxName string) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
mbox, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, mailboxName)
if mbox == nil {
return noSuchMailbox()
}
if mbox.Type != "" {
return &imap.Error{Type: imap.StatusResponseTypeNo, Text: "Cannot delete system mailbox"}
}
_, err := s.deps.DB.SQL().ExecContext(ctx, "DELETE FROM mailboxes WHERE id=?", mbox.ID)
return imapErr(err)
}
func (s *IMAPSession) Rename(oldName, newName string, options *imap.RenameOptions) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
mbox, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, oldName)
if mbox == nil {
return noSuchMailbox()
}
return imapErr(s.deps.DB.RenameMailbox(ctx, mbox.ID, newName))
}
func (s *IMAPSession) Subscribe(mailboxName string) error {
return s.setSubscribed(mailboxName, true)
}
func (s *IMAPSession) Unsubscribe(mailboxName string) error {
return s.setSubscribed(mailboxName, false)
}
func (s *IMAPSession) setSubscribed(name string, sub bool) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
mbox, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, name)
if mbox == nil {
return noSuchMailbox()
}
return imapErr(s.deps.DB.SetMailboxSubscribed(ctx, mbox.ID, sub))
}
func (s *IMAPSession) List(w *imapserver.ListWriter, ref string, patterns []string, options *imap.ListOptions) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if len(patterns) == 0 {
return w.WriteList(&imap.ListData{
Attrs: []imap.MailboxAttr{imap.MailboxAttrNoSelect},
Delim: mailboxDelim,
})
}
mailboxes, err := s.deps.DB.ListMailboxes(ctx, s.user.ID)
if err != nil {
return imapErr(err)
}
type entry struct {
name string
data imap.ListData
}
var entries []entry
for _, mbox := range mailboxes {
if options.SelectSubscribed && !mbox.Subscribed {
continue
}
matched := false
for _, pat := range patterns {
if imapserver.MatchList(mbox.Name, mailboxDelim, ref, pat) {
matched = true
break
}
}
if !matched {
continue
}
data := imap.ListData{
Mailbox: mbox.Name,
Delim: mailboxDelim,
Attrs: mboxAttrs(mbox),
}
if mbox.Subscribed {
data.Attrs = append(data.Attrs, imap.MailboxAttrSubscribed)
}
if options.ReturnStatus != nil {
sd, _ := s.statusFor(ctx, mbox, options.ReturnStatus)
data.Status = sd
}
entries = append(entries, entry{name: mbox.Name, data: data})
}
sort.Slice(entries, func(i, j int) bool { return entries[i].name < entries[j].name })
for _, e := range entries {
if err := w.WriteList(&e.data); err != nil {
return err
}
}
return nil
}
func (s *IMAPSession) Status(mailboxName string, options *imap.StatusOptions) (*imap.StatusData, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
mbox, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, mailboxName)
if mbox == nil {
return nil, noSuchMailbox()
}
return s.statusFor(ctx, mbox, options)
}
func (s *IMAPSession) statusFor(ctx context.Context, mbox *models.Mailbox, opts *imap.StatusOptions) (*imap.StatusData, error) {
data := &imap.StatusData{Mailbox: mbox.Name}
total, unseen, err := s.deps.DB.GetMailboxMessageCounts(ctx, mbox.ID)
if err != nil {
return nil, imapErr(err)
}
if opts.NumMessages {
n := uint32(total)
data.NumMessages = &n
}
if opts.NumUnseen {
n := uint32(unseen)
data.NumUnseen = &n
}
if opts.UIDNext {
data.UIDNext = imap.UID(mbox.UIDNext)
}
if opts.UIDValidity {
data.UIDValidity = mbox.UIDValidity
}
if opts.NumRecent {
n := uint32(0)
data.NumRecent = &n
}
if opts.NumDeleted {
n := uint32(0)
data.NumDeleted = &n
}
if opts.Size {
sz, _ := s.deps.DB.GetMailboxSize(ctx, mbox.ID)
data.Size = &sz
}
return data, nil
}
func (s *IMAPSession) Append(mailboxName string, r imap.LiteralReader, options *imap.AppendOptions) (*imap.AppendData, error) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
mbox, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, mailboxName)
if mbox == nil {
return nil, &imap.Error{Type: imap.StatusResponseTypeNo, Code: imap.ResponseCodeTryCreate, Text: "No such mailbox"}
}
raw, err := readLiteral(r, s.deps.Cfg.MaxMessageSize)
if err != nil {
return nil, imapErr(err)
}
key, err := s.deps.Crypt.DeriveKey("messages", s.user.ID)
if err != nil {
return nil, imapErr(fmt.Errorf("derive key: %w", err))
}
rawEnc, err := crypto.Encrypt(key, raw)
if err != nil {
return nil, imapErr(err)
}
uid, err := s.deps.DB.NextUID(ctx, mbox.ID)
if err != nil {
return nil, imapErr(err)
}
internalDate := time.Now().UTC()
isRead, isDraft := false, false
var extraParts []string
if options != nil {
if !options.Time.IsZero() {
internalDate = options.Time
}
for _, f := range options.Flags {
switch f {
case imap.FlagSeen:
isRead = true
case imap.FlagDraft:
isDraft = true
default:
extraParts = append(extraParts, string(f))
}
}
}
ins := &db.MessageInsert{
MailboxID: mbox.ID,
UID: uid,
SizeBytes: int64(len(raw)),
RawEnc: rawEnc,
IsRead: isRead,
IsDraft: isDraft,
Flags: strings.Join(extraParts, " "),
Date: internalDate,
}
if _, err := s.deps.DB.InsertMessage(ctx, ins); err != nil {
return nil, imapErr(err)
}
return &imap.AppendData{
UIDValidity: mbox.UIDValidity,
UID: imap.UID(uid),
}, nil
}
func (s *IMAPSession) Poll(w *imapserver.UpdateWriter, allowExpunge bool) error {
if s.sessionTracker == nil {
return nil
}
return s.sessionTracker.Poll(w, allowExpunge)
}
func (s *IMAPSession) Idle(w *imapserver.UpdateWriter, stop <-chan struct{}) error {
if s.sessionTracker == nil {
return nil
}
return s.sessionTracker.Idle(w, stop)
}
// ---- Selected-state operations ----
func (s *IMAPSession) Expunge(w *imapserver.ExpungeWriter, uids *imap.UIDSet) error {
if s.selectedMailbox == nil {
return noSelectedMailbox()
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
// Soft-delete entries marked \Deleted (filtered by UIDs if provided).
for i := range s.msgs {
m := &s.msgs[i]
if uids != nil && !uids.Contains(m.uid) {
continue
}
if m.isDeleted {
s.deps.DB.SoftDeleteMessage(ctx, m.dbID) //nolint:errcheck
}
}
deletedUIDs, err := s.deps.DB.HardDeleteMessages(ctx, s.selectedMailbox.ID)
if err != nil {
return imapErr(err)
}
deletedSet := make(map[imap.UID]struct{}, len(deletedUIDs))
for _, uid := range deletedUIDs {
deletedSet[imap.UID(uid)] = struct{}{}
}
// Collect seq nums to expunge (in ascending order, write in reverse).
var seqNums []uint32
for i, m := range s.msgs {
if _, ok := deletedSet[m.uid]; ok {
seqNums = append(seqNums, uint32(i+1))
}
}
for i := len(seqNums) - 1; i >= 0; i-- {
if err := w.WriteExpunge(seqNums[i]); err != nil {
return err
}
}
// Remove from in-memory list.
filtered := s.msgs[:0]
for _, m := range s.msgs {
if _, ok := deletedSet[m.uid]; !ok {
filtered = append(filtered, m)
}
}
s.msgs = filtered
return nil
}
func (s *IMAPSession) Search(kind imapserver.NumKind, criteria *imap.SearchCriteria, options *imap.SearchOptions) (*imap.SearchData, error) {
if s.selectedMailbox == nil {
return nil, noSelectedMailbox()
}
if kind == imapserver.NumKindUID {
var result imap.UIDSet
for i, m := range s.msgs {
seqNum := uint32(i + 1)
if s.matchCriteria(seqNum, &m, criteria) {
result.AddNum(m.uid)
}
}
return &imap.SearchData{All: result}, nil
}
var result imap.SeqSet
for i, m := range s.msgs {
seqNum := uint32(i + 1)
if s.matchCriteria(seqNum, &m, criteria) {
result.AddNum(seqNum)
}
}
return &imap.SearchData{All: result}, nil
}
func (s *IMAPSession) matchCriteria(seqNum uint32, m *msgEntry, c *imap.SearchCriteria) bool {
if c == nil {
return true
}
for _, seqSet := range c.SeqNum {
if !seqSet.Contains(seqNum) {
return false
}
}
for _, uidSet := range c.UID {
if !uidSet.Contains(m.uid) {
return false
}
}
flagSet := make(map[imap.Flag]struct{})
for _, f := range m.flagList() {
flagSet[f] = struct{}{}
}
for _, f := range c.Flag {
if _, ok := flagSet[f]; !ok {
return false
}
}
for _, f := range c.NotFlag {
if _, ok := flagSet[f]; ok {
return false
}
}
if c.Larger != 0 && m.size <= c.Larger {
return false
}
if c.Smaller != 0 && m.size >= c.Smaller {
return false
}
if !matchDate(m.internalDate, c.Since, c.Before) {
return false
}
for _, sub := range c.Not {
if s.matchCriteria(seqNum, m, &sub) {
return false
}
}
for _, or := range c.Or {
if !s.matchCriteria(seqNum, m, &or[0]) && !s.matchCriteria(seqNum, m, &or[1]) {
return false
}
}
return true
}
func (s *IMAPSession) Fetch(w *imapserver.FetchWriter, numSet imap.NumSet, options *imap.FetchOptions) error {
if s.selectedMailbox == nil {
return noSelectedMailbox()
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
key, err := s.deps.Crypt.DeriveKey("messages", s.user.ID)
if err != nil {
return imapErr(err)
}
for i := range s.msgs {
m := &s.msgs[i]
seqNum := uint32(i + 1)
if !numSetContains(numSet, seqNum, m.uid) {
continue
}
rw := w.CreateMessage(seqNum)
rw.WriteUID(m.uid)
if options.Flags {
rw.WriteFlags(m.flagList())
}
if options.InternalDate {
rw.WriteInternalDate(m.internalDate)
}
if options.RFC822Size {
rw.WriteRFC822Size(m.size)
}
needRaw := options.Envelope || options.BodyStructure != nil ||
len(options.BodySection) > 0 || len(options.BinarySection) > 0 ||
len(options.BinarySectionSize) > 0
if needRaw {
rawEnc, err := s.deps.DB.GetMessageRaw(ctx, m.dbID)
if err != nil || rawEnc == nil {
rw.Close() //nolint:errcheck
continue
}
raw, err := crypto.Decrypt(key, rawEnc)
if err != nil {
log.Printf("[imap] decrypt %d: %v", m.dbID, err)
rw.Close() //nolint:errcheck
continue
}
if options.Envelope {
if env := extractEnvelope(raw); env != nil {
rw.WriteEnvelope(env)
}
}
if options.BodyStructure != nil {
rw.WriteBodyStructure(imapserver.ExtractBodyStructure(bytes.NewReader(raw)))
}
for _, bs := range options.BodySection {
buf := imapserver.ExtractBodySection(bytes.NewReader(raw), bs)
wc := rw.WriteBodySection(bs, int64(len(buf)))
wc.Write(buf) //nolint:errcheck
wc.Close() //nolint:errcheck
}
for _, bs := range options.BinarySection {
buf := imapserver.ExtractBinarySection(bytes.NewReader(raw), bs)
wc := rw.WriteBinarySection(bs, int64(len(buf)))
wc.Write(buf) //nolint:errcheck
wc.Close() //nolint:errcheck
}
for _, bss := range options.BinarySectionSize {
n := imapserver.ExtractBinarySectionSize(bytes.NewReader(raw), bss)
rw.WriteBinarySectionSize(bss, n)
}
}
if err := rw.Close(); err != nil {
return err
}
}
return nil
}
func (s *IMAPSession) Store(w *imapserver.FetchWriter, numSet imap.NumSet, flags *imap.StoreFlags, options *imap.StoreOptions) error {
if s.selectedMailbox == nil {
return noSelectedMailbox()
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
for i := range s.msgs {
m := &s.msgs[i]
seqNum := uint32(i + 1)
if !numSetContains(numSet, seqNum, m.uid) {
continue
}
applyStoreFlags(m, flags)
s.deps.DB.SetMessageFlags(ctx, m.dbID, m.isRead, m.isStarred, m.isDraft, m.extraFlags) //nolint:errcheck
if m.isDeleted {
s.deps.DB.SoftDeleteMessage(ctx, m.dbID) //nolint:errcheck
}
if !flags.Silent {
rw := w.CreateMessage(seqNum)
rw.WriteUID(m.uid)
rw.WriteFlags(m.flagList())
if err := rw.Close(); err != nil {
return err
}
}
}
return nil
}
func (s *IMAPSession) Copy(numSet imap.NumSet, destName string) (*imap.CopyData, error) {
if s.selectedMailbox == nil {
return nil, noSelectedMailbox()
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
dest, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, destName)
if dest == nil {
return nil, &imap.Error{Type: imap.StatusResponseTypeNo, Code: imap.ResponseCodeTryCreate, Text: "No such mailbox"}
}
var sourceUIDs, destUIDs imap.UIDSet
for i, m := range s.msgs {
seqNum := uint32(i + 1)
if !numSetContains(numSet, seqNum, m.uid) {
continue
}
newUID, err := s.deps.DB.CopyMessageToMailbox(ctx, m.dbID, dest.ID, s.user.ID)
if err != nil {
log.Printf("[imap] copy msg %d: %v", m.dbID, err)
continue
}
sourceUIDs.AddNum(m.uid)
destUIDs.AddNum(imap.UID(newUID))
}
return &imap.CopyData{
UIDValidity: dest.UIDValidity,
SourceUIDs: sourceUIDs,
DestUIDs: destUIDs,
}, nil
}
func (s *IMAPSession) Move(w *imapserver.MoveWriter, numSet imap.NumSet, destName string) error {
copyData, err := s.Copy(numSet, destName)
if err != nil {
return err
}
if err := w.WriteCopyData(copyData); err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
var seqNums []uint32
for i := range s.msgs {
m := &s.msgs[i]
seqNum := uint32(i + 1)
if !numSetContains(numSet, seqNum, m.uid) {
continue
}
s.deps.DB.SoftDeleteMessage(ctx, m.dbID) //nolint:errcheck
seqNums = append(seqNums, seqNum)
}
s.deps.DB.HardDeleteMessages(ctx, s.selectedMailbox.ID) //nolint:errcheck
for i := len(seqNums) - 1; i >= 0; i-- {
if err := w.WriteExpunge(seqNums[i]); err != nil {
return err
}
}
// Remove moved messages.
kept := s.msgs[:0]
for i, m := range s.msgs {
seqNum := uint32(i + 1)
if !numSetContains(numSet, seqNum, m.uid) {
kept = append(kept, m)
}
}
s.msgs = kept
return nil
}
func (s *IMAPSession) Namespace() (*imap.NamespaceData, error) {
return &imap.NamespaceData{
Personal: []imap.NamespaceDescriptor{{Delim: mailboxDelim}},
}, nil
}
// ---- Helpers ----
// numSetContains checks seqNum (for SeqSet) or uid (for UIDSet).
func numSetContains(numSet imap.NumSet, seqNum uint32, uid imap.UID) bool {
switch ns := numSet.(type) {
case imap.SeqSet:
return ns.Contains(seqNum)
case imap.UIDSet:
return ns.Contains(uid)
}
return false
}
func applyStoreFlags(m *msgEntry, store *imap.StoreFlags) {
flagMap := map[imap.Flag]*bool{
imap.FlagSeen: &m.isRead,
imap.FlagFlagged: &m.isStarred,
imap.FlagDraft: &m.isDraft,
imap.FlagDeleted: &m.isDeleted,
}
switch store.Op {
case imap.StoreFlagsSet:
m.isRead, m.isStarred, m.isDraft, m.isDeleted = false, false, false, false
m.extraFlags = ""
for _, f := range store.Flags {
if ptr, ok := flagMap[f]; ok {
*ptr = true
} else {
m.extraFlags = strings.TrimSpace(m.extraFlags + " " + string(f))
}
}
case imap.StoreFlagsAdd:
for _, f := range store.Flags {
if ptr, ok := flagMap[f]; ok {
*ptr = true
} else if !strings.Contains(m.extraFlags, string(f)) {
m.extraFlags = strings.TrimSpace(m.extraFlags + " " + string(f))
}
}
case imap.StoreFlagsDel:
for _, f := range store.Flags {
if ptr, ok := flagMap[f]; ok {
*ptr = false
} else {
m.extraFlags = strings.TrimSpace(strings.ReplaceAll(m.extraFlags, string(f), ""))
}
}
}
}
func allFlags() []imap.Flag {
return []imap.Flag{
imap.FlagSeen,
imap.FlagAnswered,
imap.FlagFlagged,
imap.FlagDeleted,
imap.FlagDraft,
}
}
func mboxAttrs(mbox *models.Mailbox) []imap.MailboxAttr {
su := db.MailboxTypeToSpecialUse(mbox.Type)
if su != "" {
return []imap.MailboxAttr{imap.MailboxAttr(su)}
}
return nil
}
func extractEnvelope(raw []byte) *imap.Envelope {
br := bufio.NewReader(bytes.NewReader(raw))
header, err := textproto.ReadHeader(br)
if err != nil {
return nil
}
return imapserver.ExtractEnvelope(header)
}
func matchDate(t, since, before time.Time) bool {
t = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
if !since.IsZero() && t.Before(since) {
return false
}
if !before.IsZero() && !t.Before(before) {
return false
}
return true
}
func readLiteral(r imap.LiteralReader, maxSize int64) ([]byte, error) {
buf := make([]byte, 0, 4096)
tmp := make([]byte, 32768)
var total int64
for {
n, err := r.Read(tmp)
if n > 0 {
total += int64(n)
if total > maxSize {
return nil, fmt.Errorf("message too large")
}
buf = append(buf, tmp[:n]...)
}
if err != nil {
break
}
}
return buf, nil
}
func imapErr(err error) error {
if err == nil {
return nil
}
if _, ok := err.(*imap.Error); ok {
return err
}
return &imap.Error{Type: imap.StatusResponseTypeNo, Text: "server error"}
}
func noSuchMailbox() error {
return &imap.Error{Type: imap.StatusResponseTypeNo, Code: imap.ResponseCodeNonExistent, Text: "No such mailbox"}
}
func noSelectedMailbox() error {
return &imap.Error{Type: imap.StatusResponseTypeBad, Text: "no mailbox selected"}
}
+295
View File
@@ -0,0 +1,295 @@
// Package models defines all shared data structures.
package models
import "time"
// ---- Domain & User ----
type Domain struct {
ID int64
Name string
Enabled bool
DKIMPrivateEnc []byte // AES-256-GCM encrypted PEM private key
DKIMPublic string // PEM public key (not secret)
DKIMSelector string
DKIMAlgo string // rsa2048 | ed25519
SPFPolicy string // stored for reference
DMARCPolicy string // stored for reference
MaxUsers int
MaxQuotaBytes int64
CreatedAt time.Time
}
type User struct {
ID int64
DomainID int64
Username string // local part (before @)
Email string // full address
PasswordHash string // bcrypt
DisplayName string
QuotaBytes int64
UsedBytes int64
Enabled bool
Admin bool // global admin
DomainAdmin bool // admin of own domain only
MFASecretEnc []byte // encrypted TOTP secret; nil = MFA disabled
MFAEnabled bool
RecoveryCodesEnc []byte // encrypted JSON array of one-time codes
CreatedAt time.Time
LastLogin time.Time
}
type UserAlias struct {
ID int64
UserID int64
AliasEmail string
}
// ---- Mail storage ----
// MailboxType canonical values.
const (
MailboxInbox = "inbox"
MailboxSent = "sent"
MailboxDrafts = "drafts"
MailboxTrash = "trash"
MailboxSpam = "spam"
MailboxArchive = "archive"
MailboxCustom = "custom"
)
type Mailbox struct {
ID int64
UserID int64
Name string // IMAP mailbox name (e.g. "INBOX", "Sent", "Folder/Sub")
Type string // MailboxInbox … MailboxCustom
ParentID *int64
UIDValidity uint32
UIDNext uint32
Subscribed bool
CreatedAt time.Time
}
type Message struct {
ID int64
MailboxID int64
UID uint32
MessageID string // RFC 2822 Message-ID header (not encrypted — needed for threading)
Subject string // plaintext for list/search (consider sensitivity vs usability)
FromEmail string
FromName string
ToList string // comma-separated
CCList string
BCCList string
ReplyTo string
Date time.Time
BodyTextEnc []byte // AES-256-GCM encrypted text/plain
BodyHTMLEnc []byte // AES-256-GCM encrypted text/html
RawEnc []byte // AES-256-GCM encrypted full RFC822 message
SizeBytes int64
HasAttachment bool
IsRead bool
IsStarred bool
IsDraft bool
Flags string // raw IMAP flags string
SpamScore int
ReceivedAt time.Time
DeletedAt *time.Time // soft delete; nil = not deleted
}
type Attachment struct {
ID int64
MessageID int64
Filename string
ContentType string
SizeBytes int64
DataEnc []byte // AES-256-GCM encrypted bytes (or nil if fs-backed)
DataPath string // filesystem path (if STORAGE_BACKEND=fs)
ContentID string // MIME Content-ID for inline images
Inline bool
MIMEPath string // dot-separated MIME section path e.g. "1.2"
}
// ---- Delivery queue ----
const (
QueuePending = "pending"
QueueSending = "sending"
QueueDelivered = "delivered"
QueueFailed = "failed"
QueueBounced = "bounced"
)
type QueueEntry struct {
ID int64
DomainID int64
FromAddr string
ToAddr string
RawEnc []byte // encrypted RFC822
MessageID string
Status string // QueuePending … QueueBounced
Attempts int
LastAttempt *time.Time
NextAttempt time.Time
ErrorLog string
CreatedAt time.Time
ExpiresAt time.Time
}
type DeliveryLog struct {
ID int64
QueueID int64
FromAddr string
ToAddr string
Status string
SMTPCode int
SMTPMessage string
MXHost string
CreatedAt time.Time
}
// ---- Sessions & Security ----
type Session struct {
ID int64
UserID int64
TokenHash string // SHA-256 hex of the raw bearer token
IP string
UserAgent string
CreatedAt time.Time
ExpiresAt time.Time
}
type IPBan struct {
ID int64
IP string
Reason string
BannedAt time.Time
ExpiresAt *time.Time // nil = permanent
ReleasedBy string
}
type LoginAttempt struct {
ID int64
IP string
UserEmail string
Success bool
CreatedAt time.Time
}
type SecurityEvent struct {
ID int64
Type string // brute_ban | auth_fail | relay_attempt | etc.
IP string
UserID *int64
Detail string
CreatedAt time.Time
}
// ---- External accounts (Gmail / Outlook / custom IMAP) ----
const (
ProviderGmail = "gmail"
ProviderOutlook = "outlook"
ProviderCustom = "custom"
)
type ExternalAccount struct {
ID int64
UserID int64
Provider string // ProviderGmail | ProviderOutlook | ProviderCustom
EmailAddress string
DisplayName string
AccessTokenEnc []byte // encrypted OAuth2 or password
RefreshTokenEnc []byte // encrypted OAuth2 refresh token
TokenExpiry time.Time
IMAPHost string
IMAPPort int
SMTPHost string
SMTPPort int
Enabled bool
SyncEnabled bool
LastSync time.Time
CreatedAt time.Time
}
// ---- Contacts (CardDAV) ----
type AddressBook struct {
ID int64
UserID int64
Name string
Description string
Color string
SyncToken int64
CreatedAt time.Time
}
type Contact struct {
ID int64
AddressBookID int64
UID string
VCardEnc []byte // AES-256-GCM encrypted vCard 3.0/4.0
ETag string // hex(sha256(raw vcard)) — for sync
CreatedAt time.Time
UpdatedAt time.Time
}
// ---- Calendar (CalDAV) ----
type Calendar struct {
ID int64
UserID int64
Name string
Description string
Color string
Timezone string
SyncToken int64
CreatedAt time.Time
}
type CalendarEvent struct {
ID int64
CalendarID int64
UID string
ICalEnc []byte // AES-256-GCM encrypted iCalendar data
ETag string // hex(sha256(raw ical))
DTStart time.Time
DTEnd time.Time
Summary string // plaintext for calendar grid display
Recurring bool
CreatedAt time.Time
UpdatedAt time.Time
}
// ---- Spam / Bayesian ----
type SpamToken struct {
ID int64
UserID int64
Token string
SpamCount int64
HamCount int64
}
// ---- Compose helpers (not persisted directly) ----
type Attachment_Upload struct {
Filename string
ContentType string
Data []byte
}
type ComposeRequest struct {
AccountID int64 // 0 = local account, >0 = external account ID
FromEmail string
To []string
CC []string
BCC []string
Subject string
BodyText string
BodyHTML string
Attachments []Attachment_Upload
InReplyTo string
References string
}
+265
View File
@@ -0,0 +1,265 @@
package smtp
import (
"context"
"fmt"
"io"
"log"
"net"
"net/mail"
"strings"
"time"
gosmtp "github.com/emersion/go-smtp"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/dkim"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/spam"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/spf"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/storage"
)
// InboundBackend implements gosmtp.Backend for port 25 (receive from internet).
type InboundBackend struct {
deps *Deps
}
func (b *InboundBackend) NewSession(c *gosmtp.Conn) (gosmtp.Session, error) {
clientIP, _, _ := net.SplitHostPort(c.Conn().RemoteAddr().String())
// Check IP ban.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var banned bool
var err error
if b.deps.Cfg.BruteMaxTries > 0 && b.deps.Brute != nil {
banned, err = b.deps.Brute.IsBanned(ctx, clientIP)
if err != nil {
log.Printf("[smtp/inbound] ban check error %s: %v", clientIP, err)
}
}
if banned {
return nil, fmt.Errorf("connection refused")
}
return &InboundSession{
deps: b.deps,
clientIP: net.ParseIP(clientIP),
log: func(f string, a ...any) { log.Printf("[smtp/inbound] "+f, a...) },
}, nil
}
// InboundSession handles one inbound SMTP connection.
type InboundSession struct {
deps *Deps
clientIP net.IP
from string
fromDomain string
rcpts []string // validated local recipients (email addresses)
rcptUsers []int64 // corresponding user IDs
spfResult spf.Result
log func(string, ...any)
}
// AuthPlain is not used on port 25; always return error.
func (s *InboundSession) AuthPlain(username, password string) error {
return fmt.Errorf("AUTH not supported on port 25")
}
func (s *InboundSession) Mail(from string, opts *gosmtp.MailOptions) error {
if from == "" {
// Bounce messages use empty envelope sender — allow.
s.from = ""
s.fromDomain = ""
return nil
}
addr, err := mail.ParseAddress(from)
if err != nil {
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 7}, Message: "invalid sender address"}
}
s.from = addr.Address
at := strings.LastIndex(s.from, "@")
if at < 0 {
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 7}, Message: "invalid sender domain"}
}
s.fromDomain = strings.ToLower(s.from[at+1:])
// SPF check (async, best-effort).
if s.deps.Cfg.SpamCheckSPF && s.clientIP != nil && s.fromDomain != "" {
s.spfResult, _ = spf.Check(s.clientIP, s.fromDomain)
}
return nil
}
func (s *InboundSession) Rcpt(to string, opts *gosmtp.RcptOptions) error {
addr, err := mail.ParseAddress(to)
if err != nil {
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 3}, Message: "invalid recipient"}
}
email := strings.ToLower(addr.Address)
at := strings.LastIndex(email, "@")
if at < 0 {
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 3}, Message: "invalid recipient"}
}
domain := email[at+1:]
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Must be a local domain.
local, err := s.deps.DB.IsLocalDomain(ctx, domain)
if err != nil || !local {
return &gosmtp.SMTPError{Code: 550, EnhancedCode: gosmtp.EnhancedCode{5, 1, 2}, Message: "relay access denied"}
}
// Resolve to a user (handles aliases).
user, err := s.deps.DB.ResolveEmail(ctx, email)
if err != nil {
return &gosmtp.SMTPError{Code: 451, EnhancedCode: gosmtp.EnhancedCode{4, 3, 0}, Message: "temporary lookup error"}
}
if user == nil || !user.Enabled {
return &gosmtp.SMTPError{Code: 550, EnhancedCode: gosmtp.EnhancedCode{5, 1, 1}, Message: "user unknown"}
}
s.rcpts = append(s.rcpts, email)
s.rcptUsers = append(s.rcptUsers, user.ID)
return nil
}
func (s *InboundSession) Data(r io.Reader) error {
if len(s.rcpts) == 0 {
return &gosmtp.SMTPError{Code: 503, Message: "no recipients"}
}
// Read message with size cap (already enforced by go-smtp, but be safe).
raw, err := io.ReadAll(io.LimitReader(r, s.deps.Cfg.MaxMessageSize+1))
if err != nil {
return fmt.Errorf("read data: %w", err)
}
if int64(len(raw)) > s.deps.Cfg.MaxMessageSize {
return &gosmtp.SMTPError{Code: 552, EnhancedCode: gosmtp.EnhancedCode{5, 3, 4}, Message: "message too large"}
}
// Parse headers for metadata.
msg, err := mail.ReadMessage(strings.NewReader(string(raw)))
if err != nil {
return &gosmtp.SMTPError{Code: 550, EnhancedCode: gosmtp.EnhancedCode{5, 6, 0}, Message: "malformed message"}
}
subject := decodeHeader(msg.Header.Get("Subject"))
fromHeader := msg.Header.Get("From")
fromAddr, fromName := parseFromHeader(fromHeader)
if fromAddr == "" && s.from != "" {
fromAddr = s.from
}
msgID := msg.Header.Get("Message-ID")
dateStr := msg.Header.Get("Date")
msgDate, _ := mail.ParseDate(dateStr)
if msgDate.IsZero() {
msgDate = time.Now().UTC()
}
// DKIM verification.
dkimValid := false
dkimPresent := strings.Contains(string(raw), "DKIM-Signature:")
if dkimPresent {
_, dkimErr := dkim.Verify(raw)
dkimValid = dkimErr == nil
}
// Spam scoring (per recipient — each has their own Bayesian model).
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Deliver to each local recipient.
for i, userID := range s.rcptUsers {
rcptEmail := s.rcpts[i]
// Build spam score params.
params := &spam.Params{
ClientIP: s.clientIP,
SenderDomain: s.fromDomain,
SPFResult: s.spfResult,
DKIMValid: dkimValid,
DKIMPresent: dkimPresent,
Subject: subject,
FromHeader: fromHeader,
RecipCount: len(s.rcpts),
HasDateHeader: dateStr != "",
HasMsgIDHeader: msgID != "",
}
spamResult := s.deps.Scorer.Score(ctx, userID, params)
// Choose target mailbox.
mboxType := models.MailboxInbox
if spamResult.IsSpam {
mboxType = models.MailboxSpam
s.log("message for %s scored %d (spam), delivering to Spam", rcptEmail, spamResult.Total)
}
mbox, err := s.deps.DB.GetMailboxByType(ctx, userID, mboxType)
if err != nil || mbox == nil {
// Fallback to INBOX.
mbox, err = s.deps.DB.GetMailboxByType(ctx, userID, models.MailboxInbox)
if err != nil || mbox == nil {
s.log("no inbox for user %d (%s): %v", userID, rcptEmail, err)
continue
}
}
incoming := &storage.IncomingMessage{
Raw: raw,
FromEmail: fromAddr,
FromName: fromName,
ToList: strings.Join(s.rcpts, ", "),
Subject: subject,
Date: msgDate,
MessageID: msgID,
SpamScore: spamResult.Total,
}
_, err = s.deps.Store.SaveIncoming(ctx, userID, mbox.ID, incoming)
if err != nil {
s.log("store message for %s: %v", rcptEmail, err)
return &gosmtp.SMTPError{Code: 451, EnhancedCode: gosmtp.EnhancedCode{4, 3, 0}, Message: "storage error"}
}
s.log("delivered to %s (spam=%v score=%d)", rcptEmail, spamResult.IsSpam, spamResult.Total)
}
return nil
}
func (s *InboundSession) Reset() {
s.from = ""
s.fromDomain = ""
s.rcpts = s.rcpts[:0]
s.rcptUsers = s.rcptUsers[:0]
s.spfResult = spf.ResultNone
}
func (s *InboundSession) Logout() error { return nil }
// ---- Helpers ----
func parseFromHeader(h string) (addr, name string) {
if h == "" {
return "", ""
}
a, err := mail.ParseAddress(h)
if err != nil {
return h, ""
}
return a.Address, a.Name
}
func decodeHeader(h string) string {
// mime.WordDecoder would be imported from mime package. Keep it simple.
return h
}
+148
View File
@@ -0,0 +1,148 @@
package smtp
import (
"context"
"log"
"time"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/delivery"
)
// QueueWorker polls the delivery queue and dispatches messages.
type QueueWorker struct {
deps *Deps
interval time.Duration // how often to poll (default 30s)
}
// NewQueueWorker creates a QueueWorker backed by the given deps.
func NewQueueWorker(deps *Deps) *QueueWorker {
return &QueueWorker{
deps: deps,
interval: 30 * time.Second,
}
}
// Run polls the queue until stopCh is closed.
func (w *QueueWorker) Run(stopCh <-chan struct{}) {
log.Println("[queue] worker started")
ticker := time.NewTicker(w.interval)
defer ticker.Stop()
// Drain immediately on start.
w.drainQueue()
for {
select {
case <-stopCh:
log.Println("[queue] worker stopped")
return
case <-ticker.C:
w.drainQueue()
}
}
}
// drainQueue fetches all due entries and delivers them.
func (w *QueueWorker) drainQueue() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
entries, err := w.deps.DB.PeekQueue(ctx, 100)
if err != nil {
log.Printf("[queue] peek error: %v", err)
return
}
if len(entries) == 0 {
return
}
log.Printf("[queue] %d entries ready for delivery", len(entries))
// Derive the global queue decryption key once.
queueKey, err := w.deps.Crypt.DeriveKeyGlobal("queue")
if err != nil {
log.Printf("[queue] derive key: %v", err)
return
}
for _, entry := range entries {
// Mark in-progress to prevent parallel workers picking the same entry.
nextAttempt := time.Now().Add(5 * time.Minute) // safety: reset on success
if err := w.deps.DB.SetQueueStatus(ctx, entry.ID, "sending", "", &nextAttempt); err != nil {
log.Printf("[queue] mark sending %d: %v", entry.ID, err)
continue
}
raw, err := crypto.Decrypt(queueKey, entry.RawEnc)
if err != nil {
log.Printf("[queue] decrypt %d: %v", entry.ID, err)
w.markFailed(ctx, entry.ID, entry.FromAddr, entry.ToAddr, "decrypt error: "+err.Error(), true)
continue
}
ehlo := w.deps.Cfg.SMTPHostname
if ehlo == "" {
ehlo = w.deps.Cfg.Hostname
}
result := delivery.Deliver(ctx, ehlo, entry.FromAddr, entry.ToAddr, raw)
if result.SMTPCode == 250 {
// Success.
if err := w.deps.DB.SetQueueStatus(ctx, entry.ID, "delivered", "delivered", nil); err != nil {
log.Printf("[queue] mark delivered %d: %v", entry.ID, err)
}
w.deps.DB.LogDelivery(ctx, entry.ID, //nolint:errcheck
entry.FromAddr, entry.ToAddr,
"delivered", result.SMTPCode, result.Message, result.MXHost)
log.Printf("[queue] delivered %d: %s → %s via %s", entry.ID, entry.FromAddr, entry.ToAddr, result.MXHost)
continue
}
// Failure.
w.markFailed(ctx, entry.ID, entry.FromAddr, entry.ToAddr, result.Message, result.Perm)
}
}
// markFailed updates queue status with exponential back-off or marks permanent failure.
func (w *QueueWorker) markFailed(ctx context.Context, queueID int64, from, to, errMsg string, perm bool) {
status := "failed"
if perm {
status = "bounced"
}
var nextAttempt *time.Time
if !perm {
// Exponential back-off: 5m, 15m, 1h, 4h, 8h (then give up per expires_at).
backoff := w.nextBackoff(ctx, queueID)
t := time.Now().Add(backoff)
nextAttempt = &t
}
if err := w.deps.DB.SetQueueStatus(ctx, queueID, status, errMsg, nextAttempt); err != nil {
log.Printf("[queue] update status %d: %v", queueID, err)
}
w.deps.DB.LogDelivery(ctx, queueID, from, to, status, 0, errMsg, "") //nolint:errcheck
log.Printf("[queue] %s %d → %s: %s", status, queueID, to, errMsg)
}
// nextBackoff returns the back-off duration based on attempt count using
// the configured schedule (or a default).
func (w *QueueWorker) nextBackoff(ctx context.Context, queueID int64) time.Duration {
var attempts int
_ = w.deps.DB.SQL().QueryRowContext(ctx,
"SELECT attempts FROM queue WHERE id=?", queueID).Scan(&attempts)
schedule := w.deps.Cfg.QueueRetryMins
if len(schedule) == 0 {
// Default: 5m, 15m, 60m, 240m, 480m
schedule = []int{5, 15, 60, 240, 480}
}
idx := attempts
if idx >= len(schedule) {
idx = len(schedule) - 1
}
return time.Duration(schedule[idx]) * time.Minute
}
+104
View File
@@ -0,0 +1,104 @@
// Package smtp provides SMTP inbound (port 25) and submission (587/465) servers
// using github.com/emersion/go-smtp as the protocol layer.
package smtp
import (
"crypto/tls"
"fmt"
"log"
"time"
gosmtp "github.com/emersion/go-smtp"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/auth"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/config"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/spam"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/storage"
)
// Deps groups all dependencies shared by SMTP servers.
type Deps struct {
DB *db.DB
Crypt *crypto.Crypto
Store *storage.Store
Scorer *spam.Scorer
Brute *auth.BruteGuard
Cfg *config.Config
}
// NewInboundServer creates the SMTP server for port 25 (inbound MTA).
// Accepts mail for local domains from any sender. STARTTLS offered but not required.
func NewInboundServer(d *Deps, tlsCfg *tls.Config) *gosmtp.Server {
be := &InboundBackend{deps: d}
s := gosmtp.NewServer(be)
s.Addr = fmt.Sprintf("%s:%d", d.Cfg.SMTPIface, d.Cfg.SMTPPort)
s.Domain = d.Cfg.SMTPHostname
s.MaxMessageBytes = d.Cfg.MaxMessageSize
s.MaxRecipients = d.Cfg.MaxRcptPer
s.WriteTimeout = 5 * time.Minute
s.ReadTimeout = 5 * time.Minute
s.AllowInsecureAuth = false // AUTH not offered on port 25
if tlsCfg != nil {
// Setting TLSConfig enables STARTTLS automatically in go-smtp v0.24+.
s.TLSConfig = tlsCfg
}
return s
}
// NewSubmissionServer creates the SMTP server for port 587 (STARTTLS submission).
// Auth required. STARTTLS mandatory before AUTH.
func NewSubmissionServer(d *Deps, tlsCfg *tls.Config) *gosmtp.Server {
be := &SubmissionBackend{deps: d}
s := gosmtp.NewServer(be)
s.Addr = fmt.Sprintf("%s:%d", d.Cfg.SubmitIface, d.Cfg.SubmitPort)
s.Domain = d.Cfg.SMTPHostname
s.MaxMessageBytes = d.Cfg.MaxMessageSize
s.MaxRecipients = d.Cfg.MaxRcptPer
s.WriteTimeout = 5 * time.Minute
s.ReadTimeout = 5 * time.Minute
s.AllowInsecureAuth = false // auth only after STARTTLS
if tlsCfg != nil {
s.TLSConfig = tlsCfg
}
return s
}
// NewSMTPSServer creates the SMTP server for port 465 (implicit TLS submission).
func NewSMTPSServer(d *Deps, tlsCfg *tls.Config) *gosmtp.Server {
be := &SubmissionBackend{deps: d}
s := gosmtp.NewServer(be)
s.Addr = fmt.Sprintf("%s:%d", d.Cfg.SMTPIface, d.Cfg.SMTPSPort)
s.Domain = d.Cfg.SMTPHostname
s.MaxMessageBytes = d.Cfg.MaxMessageSize
s.MaxRecipients = d.Cfg.MaxRcptPer
s.WriteTimeout = 5 * time.Minute
s.ReadTimeout = 5 * time.Minute
s.AllowInsecureAuth = false
if tlsCfg != nil {
s.TLSConfig = tlsCfg
}
return s
}
// ListenAndServe starts a server and logs errors.
func ListenAndServe(s *gosmtp.Server, name string) error {
log.Printf("[%s] listening on %s", name, s.Addr)
return s.ListenAndServe()
}
// ListenAndServeTLS starts a server with implicit TLS (port 465).
func ListenAndServeTLS(s *gosmtp.Server, name string) error {
log.Printf("[%s] listening on %s (TLS)", name, s.Addr)
return s.ListenAndServeTLS()
}
+281
View File
@@ -0,0 +1,281 @@
package smtp
import (
"bytes"
"context"
"fmt"
"io"
"log"
"net"
"net/mail"
"strings"
"time"
gosmtp "github.com/emersion/go-smtp"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/dkim"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/storage"
)
// SubmissionBackend implements gosmtp.Backend for ports 587/465.
// Requires authenticated users. Signs outbound mail with DKIM. Queues for delivery.
type SubmissionBackend struct {
deps *Deps
}
func (b *SubmissionBackend) NewSession(c *gosmtp.Conn) (gosmtp.Session, error) {
clientIP, _, _ := net.SplitHostPort(c.Conn().RemoteAddr().String())
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var banned bool
if b.deps.Brute != nil {
banned, _ = b.deps.Brute.IsBanned(ctx, clientIP)
}
if banned {
return nil, fmt.Errorf("connection refused")
}
return &SubmissionSession{
deps: b.deps,
clientIP: clientIP,
}, nil
}
// SubmissionSession handles one authenticated submission connection.
type SubmissionSession struct {
deps *Deps
clientIP string
user *models.User // set after AUTH
from string
rcpts []string
}
func (s *SubmissionSession) AuthPlain(username, password string) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
user, err := s.deps.DB.GetUserByEmail(ctx, username)
if err != nil {
log.Printf("[smtp/submission] auth lookup error %s: %v", username, err)
return &gosmtp.SMTPError{Code: 535, Message: "authentication failed"}
}
if user == nil || !user.Enabled {
s.logAttempt(ctx, username, false)
return &gosmtp.SMTPError{Code: 535, Message: "authentication failed"}
}
if err := crypto.CheckPassword(user.PasswordHash, password); err != nil {
s.logAttempt(ctx, username, false)
log.Printf("[smtp/submission] auth failed for %s from %s", username, s.clientIP)
return &gosmtp.SMTPError{Code: 535, Message: "authentication failed"}
}
s.user = user
s.deps.DB.UpdateLastLogin(ctx, user.ID)
log.Printf("[smtp/submission] auth OK for %s from %s", username, s.clientIP)
return nil
}
func (s *SubmissionSession) Mail(from string, opts *gosmtp.MailOptions) error {
if s.user == nil {
return &gosmtp.SMTPError{Code: 530, Message: "authentication required"}
}
addr, err := mail.ParseAddress(from)
if err != nil {
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 7}, Message: "invalid sender"}
}
// Sender must be user's own address or an alias they own.
fromEmail := strings.ToLower(addr.Address)
if !strings.EqualFold(fromEmail, s.user.Email) {
// Check aliases.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
resolved, err := s.deps.DB.ResolveEmail(ctx, fromEmail)
if err != nil || resolved == nil || resolved.ID != s.user.ID {
return &gosmtp.SMTPError{Code: 553, EnhancedCode: gosmtp.EnhancedCode{5, 1, 8}, Message: "sender not owned by authenticated user"}
}
}
s.from = addr.Address
return nil
}
func (s *SubmissionSession) Rcpt(to string, opts *gosmtp.RcptOptions) error {
if s.user == nil {
return &gosmtp.SMTPError{Code: 530, Message: "authentication required"}
}
addr, err := mail.ParseAddress(to)
if err != nil {
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 3}, Message: "invalid recipient"}
}
s.rcpts = append(s.rcpts, addr.Address)
return nil
}
func (s *SubmissionSession) Data(r io.Reader) error {
if s.user == nil {
return &gosmtp.SMTPError{Code: 530, Message: "authentication required"}
}
if len(s.rcpts) == 0 {
return &gosmtp.SMTPError{Code: 503, Message: "no recipients"}
}
raw, err := io.ReadAll(io.LimitReader(r, s.deps.Cfg.MaxMessageSize+1))
if err != nil {
return fmt.Errorf("read submission data: %w", err)
}
if int64(len(raw)) > s.deps.Cfg.MaxMessageSize {
return &gosmtp.SMTPError{Code: 552, EnhancedCode: gosmtp.EnhancedCode{5, 3, 4}, Message: "message too large"}
}
// Parse for basic header validation.
_, err = mail.ReadMessage(bytes.NewReader(raw))
if err != nil {
return &gosmtp.SMTPError{Code: 550, EnhancedCode: gosmtp.EnhancedCode{5, 6, 0}, Message: "malformed message"}
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// DKIM-sign the message if the sender's domain has keys configured.
senderDomain := domainOf(s.from)
raw = s.signDKIM(ctx, raw, senderDomain)
msgID := extractMsgID(raw)
// Queue each recipient for delivery.
// For local recipients we could deliver directly, but queuing is simpler and
// provides a consistent audit trail.
dom, err := s.deps.DB.GetDomain(ctx, senderDomain)
var domainID int64
if err == nil && dom != nil {
domainID = dom.ID
}
// Encrypt raw for queue storage using a global (non-user) key.
queueKey, err := s.deps.Crypt.DeriveKeyGlobal("queue")
if err != nil {
return fmt.Errorf("queue key: %w", err)
}
rawEnc, err := crypto.Encrypt(queueKey, raw)
if err != nil {
return fmt.Errorf("encrypt for queue: %w", err)
}
for _, rcpt := range s.rcpts {
_, err := s.deps.DB.EnqueueMessage(ctx, domainID, s.from, rcpt, msgID, rawEnc, s.deps.Cfg.QueueMaxAgeHours)
if err != nil {
log.Printf("[smtp/submission] enqueue to %s: %v", rcpt, err)
return &gosmtp.SMTPError{Code: 451, EnhancedCode: gosmtp.EnhancedCode{4, 3, 0}, Message: "queue error"}
}
log.Printf("[smtp/submission] queued %s → %s", s.from, rcpt)
}
// Also save a copy in sender's Sent folder.
s.saveSentCopy(ctx, raw)
return nil
}
func (s *SubmissionSession) Reset() {
s.from = ""
s.rcpts = s.rcpts[:0]
}
func (s *SubmissionSession) Logout() error { return nil }
// signDKIM signs the message with the sender domain's DKIM key if available.
// Returns the original raw on any error (DKIM is best-effort).
func (s *SubmissionSession) signDKIM(ctx context.Context, raw []byte, senderDomain string) []byte {
dom, err := s.deps.DB.GetDomain(ctx, senderDomain)
if err != nil || dom == nil || dom.DKIMPrivateEnc == nil {
return raw // no key configured
}
privPEM, err := s.deps.Crypt.DecryptGlobal("dkim", dom.DKIMPrivateEnc)
if err != nil {
log.Printf("[smtp/submission] dkim key decrypt for %s: %v", senderDomain, err)
return raw
}
signer, err := dkim.NewSigner(string(privPEM), senderDomain, dom.DKIMSelector)
if err != nil {
log.Printf("[smtp/submission] dkim signer for %s: %v", senderDomain, err)
return raw
}
header, err := signer.Sign(raw)
if err != nil {
log.Printf("[smtp/submission] dkim sign for %s: %v", senderDomain, err)
return raw
}
// Prepend DKIM-Signature header.
return append([]byte(header+"\r\n"), raw...)
}
// saveSentCopy stores a copy in the user's Sent mailbox (best-effort).
func (s *SubmissionSession) saveSentCopy(ctx context.Context, raw []byte) {
mbox, err := s.deps.DB.GetMailboxByType(ctx, s.user.ID, models.MailboxSent)
if err != nil || mbox == nil {
return
}
msg, err := mail.ReadMessage(strings.NewReader(string(raw)))
if err != nil {
return
}
subject := msg.Header.Get("Subject")
msgDate, _ := mail.ParseDate(msg.Header.Get("Date"))
if msgDate.IsZero() {
msgDate = time.Now().UTC()
}
incoming := &storage.IncomingMessage{
Raw: raw,
FromEmail: s.from,
ToList: strings.Join(s.rcpts, ", "),
Subject: subject,
Date: msgDate,
MessageID: extractMsgID(raw),
SpamScore: 0,
}
if _, err := s.deps.Store.SaveIncoming(ctx, s.user.ID, mbox.ID, incoming); err != nil {
log.Printf("[smtp/submission] save sent copy: %v", err)
}
}
// ---- helpers ----
func (s *SubmissionSession) logAttempt(ctx context.Context, email string, success bool) {
if s.deps.Brute != nil {
s.deps.Brute.RecordAttempt(ctx, s.clientIP, email, success) //nolint:errcheck
}
}
func domainOf(email string) string {
at := strings.LastIndex(email, "@")
if at < 0 {
return ""
}
return strings.ToLower(email[at+1:])
}
func extractMsgID(raw []byte) string {
msg, err := mail.ReadMessage(bytes.NewReader(raw))
if err != nil {
return ""
}
return msg.Header.Get("Message-ID")
}
+314
View File
@@ -0,0 +1,314 @@
// Package spam scores inbound messages using static heuristics plus
// optional Bayesian token analysis (per RFC 5965 conventions).
// Score >= threshold → deliver to Spam folder.
package spam
import (
"context"
"database/sql"
"fmt"
"math"
"net"
"strings"
"time"
"unicode"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/spf"
)
// Scorer evaluates spam likelihood.
type Scorer struct {
db *db.DB
threshold int
dnsbl []string
checkSPF bool
checkDKIM bool
}
// Result holds the total spam score and the component breakdown.
type Result struct {
Total int
Reasons []string
IsSpam bool
}
// Params groups message features for scoring.
type Params struct {
ClientIP net.IP
SenderDomain string
SPFResult spf.Result
DKIMValid bool
DKIMPresent bool
DMARCFail bool
Subject string
FromHeader string
HasHTMLOnly bool // true if no text/plain part
RecipCount int
HasDateHeader bool
HasMsgIDHeader bool
BodyText string // first 1000 bytes of plain text for token analysis
}
// NewScorer creates a scorer from config values.
func NewScorer(database *db.DB, threshold int, dnsbl []string, checkSPF, checkDKIM bool) *Scorer {
return &Scorer{
db: database,
threshold: threshold,
dnsbl: dnsbl,
checkSPF: checkSPF,
checkDKIM: checkDKIM,
}
}
// Score evaluates the message and returns a Result.
func (s *Scorer) Score(ctx context.Context, userID int64, p *Params) *Result {
r := &Result{}
add := func(pts int, reason string) {
r.Total += pts
r.Reasons = append(r.Reasons, fmt.Sprintf("+%d: %s", pts, reason))
}
// DNSBL check (async-ish: each lookup gets its own goroutine with timeout)
if p.ClientIP != nil {
hits := s.dnsblCheck(ctx, p.ClientIP)
for _, bl := range hits {
add(5, "DNSBL hit: "+bl)
}
}
// SPF
if s.checkSPF {
switch p.SPFResult {
case spf.ResultFail:
add(4, "SPF fail")
case spf.ResultSoftFail:
add(2, "SPF softfail")
case spf.ResultNone:
add(1, "SPF none (no record)")
}
}
// DKIM
if s.checkDKIM {
if !p.DKIMPresent {
add(2, "DKIM absent")
} else if !p.DKIMValid {
add(3, "DKIM invalid signature")
}
}
// DMARC
if p.DMARCFail {
add(5, "DMARC fail")
}
// Missing required headers.
if !p.HasDateHeader {
add(1, "missing Date header")
}
if !p.HasMsgIDHeader {
add(1, "missing Message-ID header")
}
// HTML-only (no text/plain).
if p.HasHTMLOnly {
add(1, "HTML-only body")
}
// All-caps subject.
if p.Subject != "" && isAllCaps(p.Subject) {
add(1, "all-caps subject")
}
// Excessive recipients.
if p.RecipCount > 20 {
add(2, fmt.Sprintf("excessive recipients (%d)", p.RecipCount))
}
// Bayesian (per-user trained model).
if userID > 0 && p.BodyText != "" {
bayesScore, err := s.bayesScore(ctx, userID, p.BodyText)
if err == nil && bayesScore >= 0.8 {
pts := int((bayesScore - 0.7) * 20) // 0.8→2 pts, 0.9→4 pts, 1.0→6 pts
add(pts, fmt.Sprintf("Bayesian score %.2f", bayesScore))
}
}
r.IsSpam = r.Total >= s.threshold
return r
}
// TrainSpam adds body tokens to the user's spam corpus.
func (s *Scorer) TrainSpam(ctx context.Context, userID int64, body string) error {
return s.trainTokens(ctx, userID, body, true)
}
// TrainHam adds body tokens to the user's ham corpus.
func (s *Scorer) TrainHam(ctx context.Context, userID int64, body string) error {
return s.trainTokens(ctx, userID, body, false)
}
// ---- DNSBL ----
func (s *Scorer) dnsblCheck(ctx context.Context, ip net.IP) []string {
ipv4 := ip.To4()
if ipv4 == nil {
return nil // DNSBL queries are IPv4-only for now
}
// Reverse the IP octets: 1.2.3.4 → 4.3.2.1
reversed := fmt.Sprintf("%d.%d.%d.%d", ipv4[3], ipv4[2], ipv4[1], ipv4[0])
type result struct{ bl string }
hits := make(chan result, len(s.dnsbl))
timeout, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
for _, bl := range s.dnsbl {
bl := bl
go func() {
query := reversed + "." + bl
addrs, err := net.DefaultResolver.LookupHost(timeout, query)
if err == nil && len(addrs) > 0 {
hits <- result{bl}
} else {
hits <- result{}
}
}()
}
var matched []string
for range s.dnsbl {
if h := <-hits; h.bl != "" {
matched = append(matched, h.bl)
}
}
return matched
}
// ---- Bayesian ----
func tokenize(body string) []string {
body = strings.ToLower(body)
var tokens []string
seen := make(map[string]struct{})
words := strings.FieldsFunc(body, func(r rune) bool {
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
})
for _, w := range words {
if len(w) < 3 || len(w) > 30 {
continue
}
if _, ok := seen[w]; ok {
continue
}
seen[w] = struct{}{}
tokens = append(tokens, w)
if len(tokens) >= 200 {
break
}
}
return tokens
}
// bayesScore returns the probability [0,1] that the message is spam.
func (s *Scorer) bayesScore(ctx context.Context, userID int64, body string) (float64, error) {
tokens := tokenize(body)
if len(tokens) == 0 {
return 0, nil
}
// Fetch total spam/ham message counts for this user (proxy: count rows with nonzero counts).
var totalSpam, totalHam int64
err := s.db.SQL().QueryRowContext(ctx, `
SELECT COALESCE(SUM(spam_count),0), COALESCE(SUM(ham_count),0)
FROM spam_tokens WHERE user_id=?`, userID).Scan(&totalSpam, &totalHam)
if err != nil || (totalSpam+totalHam) < 50 {
// Not enough training data.
return 0, nil
}
// Naive Bayes: P(spam|words) ∝ Π P(word|spam) / Π P(word|ham)
// Use log-probabilities to avoid underflow.
var logP float64
placeholders := make([]string, len(tokens))
args := make([]interface{}, len(tokens)+1)
args[0] = userID
for i, tok := range tokens {
placeholders[i] = "?"
args[i+1] = tok
}
query := fmt.Sprintf(`
SELECT token, spam_count, ham_count FROM spam_tokens
WHERE user_id=? AND token IN (%s)`, strings.Join(placeholders, ","))
rows, err := s.db.SQL().QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer rows.Close()
tokenData := make(map[string][2]int64)
for rows.Next() {
var tok string
var sc, hc int64
if err := rows.Scan(&tok, &sc, &hc); err != nil {
return 0, err
}
tokenData[tok] = [2]int64{sc, hc}
}
for _, tok := range tokens {
counts := tokenData[tok]
sc, hc := counts[0], counts[1]
// Laplace smoothing.
pSpam := float64(sc+1) / float64(totalSpam+2)
pHam := float64(hc+1) / float64(totalHam+2)
logP += math.Log(pSpam) - math.Log(pHam)
}
// Convert log-odds back to probability.
prob := 1.0 / (1.0 + math.Exp(-logP))
return prob, nil
}
func (s *Scorer) trainTokens(ctx context.Context, userID int64, body string, isSpam bool) error {
tokens := tokenize(body)
for _, tok := range tokens {
var col string
if isSpam {
col = "spam_count"
} else {
col = "ham_count"
}
_, err := s.db.SQL().ExecContext(ctx, fmt.Sprintf(`
INSERT INTO spam_tokens (user_id, token, %s)
VALUES (?, ?, 1)
ON CONFLICT(user_id, token) DO UPDATE SET %s=%s+1`, col, col, col),
userID, tok)
if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("train token %q: %w", tok, err)
}
}
return nil
}
func isAllCaps(s string) bool {
hasLetter := false
for _, r := range s {
if unicode.IsLetter(r) {
hasLetter = true
if unicode.IsLower(r) {
return false
}
}
}
return hasLetter
}
+200
View File
@@ -0,0 +1,200 @@
// Package spf implements basic SPF (RFC 7208) DNS lookup and policy evaluation.
// Only stdlib net is used for DNS. Lookup limit: 10 DNS mechanisms per spec.
package spf
import (
"fmt"
"net"
"strings"
)
// Result values per RFC 7208 §2.6.
type Result int
const (
ResultNone Result = iota // no SPF record
ResultNeutral // ?all
ResultPass // sender is permitted
ResultFail // sender is not permitted (hard fail)
ResultSoftFail // ~all (likely spam)
ResultTempError // transient DNS error
ResultPermError // permanent SPF parse error
)
func (r Result) String() string {
return [...]string{"none", "neutral", "pass", "fail", "softfail", "temperror", "permerror"}[r]
}
// Check performs SPF evaluation for the given sender IP and envelope-from domain.
// Returns the SPF result and an explanation string.
func Check(clientIP net.IP, senderDomain string) (Result, string) {
if senderDomain == "" {
return ResultNone, "empty sender domain"
}
record, err := fetchSPF(senderDomain)
if err != nil {
return ResultTempError, fmt.Sprintf("SPF DNS lookup %s: %v", senderDomain, err)
}
if record == "" {
return ResultNone, "no SPF record for " + senderDomain
}
result, msg := evaluate(clientIP, senderDomain, record, 0)
return result, msg
}
// fetchSPF queries TXT records for the domain and returns the first SPF record.
func fetchSPF(domain string) (string, error) {
txts, err := net.LookupTXT(domain)
if err != nil {
// Treat NXDOMAIN as "no record" not an error.
if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound {
return "", nil
}
return "", err
}
for _, txt := range txts {
txt = strings.TrimSpace(txt)
if strings.HasPrefix(txt, "v=spf1") {
return txt, nil
}
}
return "", nil
}
// evaluate parses and applies SPF mechanisms. dnsLookups tracks the lookup count.
func evaluate(ip net.IP, domain, record string, dnsLookups int) (Result, string) {
parts := strings.Fields(record)
if len(parts) == 0 || !strings.EqualFold(parts[0], "v=spf1") {
return ResultPermError, "invalid SPF record: " + record
}
for _, term := range parts[1:] {
if dnsLookups > 10 {
return ResultPermError, "exceeded 10 DNS lookups"
}
// Qualifier prefix: +pass -fail ~softfail ?neutral
qualifier := ResultPass
switch term[0] {
case '+':
qualifier = ResultPass
term = term[1:]
case '-':
qualifier = ResultFail
term = term[1:]
case '~':
qualifier = ResultSoftFail
term = term[1:]
case '?':
qualifier = ResultNeutral
term = term[1:]
}
lower := strings.ToLower(term)
switch {
case lower == "all":
return qualifier, "matched 'all'"
case strings.HasPrefix(lower, "ip4:"):
cidr := term[4:]
if !strings.Contains(cidr, "/") {
cidr += "/32"
}
_, network, err := net.ParseCIDR(cidr)
if err != nil {
continue
}
if ip.To4() != nil && network.Contains(ip) {
return qualifier, fmt.Sprintf("matched ip4:%s", term[4:])
}
case strings.HasPrefix(lower, "ip6:"):
cidr := term[4:]
if !strings.Contains(cidr, "/") {
cidr += "/128"
}
_, network, err := net.ParseCIDR(cidr)
if err != nil {
continue
}
if ip.To4() == nil && network.Contains(ip) {
return qualifier, fmt.Sprintf("matched ip6:%s", term[4:])
}
case lower == "a" || strings.HasPrefix(lower, "a:") || strings.HasPrefix(lower, "a/"):
dnsLookups++
checkDomain := domain
if strings.HasPrefix(lower, "a:") {
checkDomain = term[2:]
}
addrs, err := net.LookupHost(checkDomain)
if err == nil {
for _, addr := range addrs {
if net.ParseIP(addr).Equal(ip) {
return qualifier, fmt.Sprintf("matched a:%s", checkDomain)
}
}
}
case lower == "mx" || strings.HasPrefix(lower, "mx:") || strings.HasPrefix(lower, "mx/"):
dnsLookups++
checkDomain := domain
if strings.HasPrefix(lower, "mx:") {
checkDomain = term[3:]
}
mxs, err := net.LookupMX(checkDomain)
if err == nil {
for _, mx := range mxs {
dnsLookups++
addrs, err := net.LookupHost(mx.Host)
if err == nil {
for _, addr := range addrs {
if net.ParseIP(addr).Equal(ip) {
return qualifier, fmt.Sprintf("matched mx:%s", checkDomain)
}
}
}
}
}
case strings.HasPrefix(lower, "include:"):
includeDomain := term[8:]
dnsLookups++
includeRecord, err := fetchSPF(includeDomain)
if err != nil {
return ResultTempError, fmt.Sprintf("include %s DNS error: %v", includeDomain, err)
}
if includeRecord == "" {
return ResultPermError, "include domain has no SPF: " + includeDomain
}
subResult, subMsg := evaluate(ip, includeDomain, includeRecord, dnsLookups)
if subResult == ResultPass {
return qualifier, fmt.Sprintf("include:%s → %s", includeDomain, subMsg)
}
case strings.HasPrefix(lower, "redirect="):
redirectDomain := term[9:]
dnsLookups++
redirectRecord, err := fetchSPF(redirectDomain)
if err != nil {
return ResultTempError, fmt.Sprintf("redirect %s DNS error: %v", redirectDomain, err)
}
if redirectRecord == "" {
return ResultNone, "redirect domain has no SPF: " + redirectDomain
}
return evaluate(ip, redirectDomain, redirectRecord, dnsLookups)
case strings.HasPrefix(lower, "exp="):
// Explanation modifier — ignore.
default:
// Unknown mechanism — ignore per spec.
}
}
// No mechanism matched.
return ResultNeutral, "no mechanism matched"
}
+358
View File
@@ -0,0 +1,358 @@
// Package storage provides encrypted message persistence for both SQLite (db)
// and filesystem (fs) backends. All body and raw data is encrypted with
// AES-256-GCM before being written.
package storage
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"mime"
"mime/multipart"
"mime/quotedprintable"
"net/mail"
"os"
"path/filepath"
"strings"
"time"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
)
// Store is the encrypted message store.
type Store struct {
db *db.DB
crypt *crypto.Crypto
backend string // "db" | "fs"
fsPath string
}
// IncomingMessage is the raw inbound message plus parsed envelope fields.
type IncomingMessage struct {
Raw []byte // full RFC822
FromEmail string
FromName string
ToList string
CCList string
BCCList string
ReplyTo string
Subject string
Date time.Time
MessageID string
SpamScore int
}
type parsedAttachment struct {
Filename string
ContentType string
Data []byte
ContentID string
Inline bool
MIMEPath string
}
// New validates the backend and returns a ready Store.
func New(database *db.DB, crypt *crypto.Crypto, backend, fsPath string) (*Store, error) {
if backend != "db" && backend != "fs" {
return nil, fmt.Errorf("storage: unknown backend %q (want \"db\" or \"fs\")", backend)
}
if backend == "fs" {
if fsPath == "" {
return nil, fmt.Errorf("storage: fsPath required for fs backend")
}
if err := os.MkdirAll(fsPath, 0700); err != nil {
return nil, fmt.Errorf("storage: create fs dir: %w", err)
}
}
return &Store{db: database, crypt: crypt, backend: backend, fsPath: fsPath}, nil
}
// SaveIncoming encrypts and persists an incoming message plus its attachments.
// Returns the new message row ID.
func (s *Store) SaveIncoming(ctx context.Context, userID, mailboxID int64, msg *IncomingMessage) (int64, error) {
uid, err := s.db.NextUID(ctx, mailboxID)
if err != nil {
return 0, fmt.Errorf("storage: next uid: %w", err)
}
bodyText, bodyHTML, attachments := parseMIME(msg.Raw)
key, err := s.crypt.DeriveKey("messages", userID)
if err != nil {
return 0, fmt.Errorf("storage: derive key: %w", err)
}
bodyTextEnc, err := crypto.Encrypt(key, []byte(bodyText))
if err != nil {
return 0, fmt.Errorf("storage: encrypt body text: %w", err)
}
bodyHTMLEnc, err := crypto.Encrypt(key, []byte(bodyHTML))
if err != nil {
return 0, fmt.Errorf("storage: encrypt body html: %w", err)
}
rawEnc, err := crypto.Encrypt(key, msg.Raw)
if err != nil {
return 0, fmt.Errorf("storage: encrypt raw: %w", err)
}
insert := &db.MessageInsert{
MailboxID: mailboxID,
UID: uid,
MessageID: msg.MessageID,
Subject: msg.Subject,
FromEmail: msg.FromEmail,
FromName: msg.FromName,
ToList: msg.ToList,
CCList: msg.CCList,
BCCList: msg.BCCList,
ReplyTo: msg.ReplyTo,
Date: msg.Date,
BodyTextEnc: bodyTextEnc,
BodyHTMLEnc: bodyHTMLEnc,
RawEnc: rawEnc,
SizeBytes: int64(len(msg.Raw)),
HasAttachment: len(attachments) > 0,
IsRead: false,
IsStarred: false,
IsDraft: false,
Flags: "",
SpamScore: msg.SpamScore,
}
messageID, err := s.db.InsertMessage(ctx, insert)
if err != nil {
return 0, fmt.Errorf("storage: insert message: %w", err)
}
if err := s.db.UpdateUsedBytes(ctx, userID, int64(len(msg.Raw))); err != nil {
return 0, fmt.Errorf("storage: update used bytes: %w", err)
}
for _, att := range attachments {
attEnc, err := crypto.Encrypt(key, att.Data)
if err != nil {
return 0, fmt.Errorf("storage: encrypt attachment: %w", err)
}
var dataPath string
var dataEnc []byte
if s.backend == "fs" {
h := sha256.Sum256(att.Data)
name := hex.EncodeToString(h[:]) + ".att"
p := filepath.Join(s.fsPath, name)
if err := os.WriteFile(p, attEnc, 0600); err != nil {
return 0, fmt.Errorf("storage: write attachment file: %w", err)
}
dataPath = p
} else {
dataEnc = attEnc
}
attInsert := &db.AttachmentInsert{
MessageID: messageID,
Filename: att.Filename,
ContentType: att.ContentType,
SizeBytes: int64(len(att.Data)),
DataEnc: dataEnc,
DataPath: dataPath,
ContentID: att.ContentID,
Inline: att.Inline,
MIMEPath: att.MIMEPath,
}
if _, err := s.db.InsertAttachment(ctx, attInsert); err != nil {
return 0, fmt.Errorf("storage: insert attachment: %w", err)
}
}
return messageID, nil
}
// GetRaw returns the decrypted RFC822 blob for a message.
func (s *Store) GetRaw(ctx context.Context, userID, messageID int64) ([]byte, error) {
rawEnc, err := s.db.GetMessageRaw(ctx, messageID)
if err != nil {
return nil, fmt.Errorf("storage: get raw enc: %w", err)
}
if rawEnc == nil {
return nil, fmt.Errorf("storage: message %d not found", messageID)
}
key, err := s.crypt.DeriveKey("messages", userID)
if err != nil {
return nil, fmt.Errorf("storage: derive key: %w", err)
}
plain, err := crypto.Decrypt(key, rawEnc)
if err != nil {
return nil, fmt.Errorf("storage: decrypt raw: %w", err)
}
return plain, nil
}
// parseMIME walks a raw RFC822 message and extracts body parts and attachments.
func parseMIME(raw []byte) (bodyText, bodyHTML string, attachments []parsedAttachment) {
m, err := mail.ReadMessage(bytes.NewReader(raw))
if err != nil {
return "", "", nil
}
ct := m.Header.Get("Content-Type")
body, err := io.ReadAll(m.Body)
if err != nil {
return "", "", nil
}
walkPart(ct, m.Header.Get("Content-Transfer-Encoding"),
m.Header.Get("Content-Disposition"), m.Header.Get("Content-ID"),
body, "1", &bodyText, &bodyHTML, &attachments)
return bodyText, bodyHTML, attachments
}
// walkPart recursively processes one MIME part.
func walkPart(
ctHeader, cteHeader, cdHeader, cidHeader string,
data []byte,
mimePath string,
bodyText, bodyHTML *string,
attachments *[]parsedAttachment,
) {
mediaType, params, err := mime.ParseMediaType(ctHeader)
if err != nil {
// Treat unparseable as text/plain.
mediaType = "text/plain"
params = map[string]string{}
}
// Decode transfer encoding first if this is a leaf part.
if !strings.HasPrefix(mediaType, "multipart/") {
data = decodeCTE(cteHeader, data)
}
switch {
case mediaType == "text/plain" && !isAttachment(cdHeader):
*bodyText = string(data)
case mediaType == "text/html" && !isAttachment(cdHeader):
*bodyHTML = string(data)
case strings.HasPrefix(mediaType, "multipart/"):
boundary, ok := params["boundary"]
if !ok {
return
}
mr := multipart.NewReader(bytes.NewReader(data), boundary)
partIdx := 1
for {
part, err := mr.NextPart()
if err != nil {
break
}
partData, err := io.ReadAll(part)
if err != nil {
part.Close()
break
}
part.Close()
subCT := part.Header.Get("Content-Type")
if subCT == "" {
subCT = "text/plain"
}
subCTE := part.Header.Get("Content-Transfer-Encoding")
subCD := part.Header.Get("Content-Disposition")
subCID := part.Header.Get("Content-ID")
subPath := fmt.Sprintf("%s.%d", mimePath, partIdx)
walkPart(subCT, subCTE, subCD, subCID, partData, subPath,
bodyText, bodyHTML, attachments)
partIdx++
}
default:
// Treat as attachment.
filename := filenameFrom(cdHeader, ctHeader)
inline := strings.HasPrefix(strings.ToLower(cdHeader), "inline")
cleanCID := strings.Trim(cidHeader, "<>")
*attachments = append(*attachments, parsedAttachment{
Filename: filename,
ContentType: mediaType,
Data: data,
ContentID: cleanCID,
Inline: inline,
MIMEPath: mimePath,
})
}
}
// decodeCTE applies quoted-printable or base64 decoding based on the
// Content-Transfer-Encoding header value.
func decodeCTE(cte string, data []byte) []byte {
switch strings.ToLower(strings.TrimSpace(cte)) {
case "quoted-printable":
decoded, err := io.ReadAll(quotedprintable.NewReader(bytes.NewReader(data)))
if err != nil {
return data
}
return decoded
case "base64":
// Strip whitespace — base64 in email is line-wrapped.
clean := strings.Map(func(r rune) rune {
if r == '\r' || r == '\n' || r == ' ' || r == '\t' {
return -1
}
return r
}, string(data))
decoded, err := base64.StdEncoding.DecodeString(clean)
if err != nil {
// Try raw/URL encoding as fallback.
decoded, err = base64.RawStdEncoding.DecodeString(clean)
if err != nil {
return data
}
}
return decoded
default:
return data
}
}
// isAttachment returns true when the Content-Disposition header signals attachment.
func isAttachment(cd string) bool {
if cd == "" {
return false
}
disp, _, _ := mime.ParseMediaType(cd)
return strings.EqualFold(disp, "attachment")
}
// filenameFrom extracts a filename from Content-Disposition, falling back to
// the name= param of Content-Type.
func filenameFrom(cd, ct string) string {
if cd != "" {
_, params, err := mime.ParseMediaType(cd)
if err == nil {
if name, ok := params["filename"]; ok && name != "" {
return filepath.Base(name)
}
}
}
if ct != "" {
_, params, err := mime.ParseMediaType(ct)
if err == nil {
if name, ok := params["name"]; ok && name != "" {
return filepath.Base(name)
}
}
}
return "attachment"
}
// Ensure models import is used (Subject field type reference avoids import pruning).
var _ = models.MailboxInbox
+350
View File
@@ -0,0 +1,350 @@
package tls
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"os"
"path/filepath"
"time"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/challenge/http01"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
"github.com/go-acme/lego/v4/providers/dns/digitalocean"
"github.com/go-acme/lego/v4/providers/dns/hetzner"
"github.com/go-acme/lego/v4/providers/dns/route53"
"github.com/go-acme/lego/v4/registration"
)
// ACME manages Let's Encrypt certificate issuance and renewal.
type ACME struct {
cfg ACMEConfig
client *lego.Client
account *acmeAccount
cacheDir string
}
// ACMEConfig holds all ACME-related settings from app config.
type ACMEConfig struct {
Email string
CacheDir string
Staging bool
Domains []string // e.g. ["example.com", "*.example.com"]
Mode string // "dns01" | "http01"
DNSProvider string // "cloudflare" | "route53" | "digitalocean" | "hetzner" | ...
}
// acmeAccount implements lego's registration.User interface.
type acmeAccount struct {
Email string `json:"email"`
Registration *registration.Resource `json:"registration"`
key crypto.PrivateKey
keyPEM []byte
}
func (a *acmeAccount) GetEmail() string { return a.Email }
func (a *acmeAccount) GetRegistration() *registration.Resource { return a.Registration }
func (a *acmeAccount) GetPrivateKey() crypto.PrivateKey { return a.key }
// NewACME initialises the ACME client. Does not yet obtain a cert.
// Call ObtainOrRenew() to get/refresh the certificate.
func NewACME(cfg ACMEConfig) (*ACME, error) {
if cfg.Email == "" {
return nil, fmt.Errorf("ACME_EMAIL required for Let's Encrypt")
}
if len(cfg.Domains) == 0 {
return nil, fmt.Errorf("ACME_DOMAINS must not be empty")
}
if err := os.MkdirAll(cfg.CacheDir, 0700); err != nil {
return nil, fmt.Errorf("acme cache dir: %w", err)
}
a := &ACME{cfg: cfg, cacheDir: cfg.CacheDir}
acc, err := a.loadOrCreateAccount()
if err != nil {
return nil, fmt.Errorf("acme account: %w", err)
}
a.account = acc
legoConfig := lego.NewConfig(a.account)
if cfg.Staging {
legoConfig.CADirURL = lego.LEDirectoryStaging
} else {
legoConfig.CADirURL = lego.LEDirectoryProduction
}
legoConfig.Certificate.KeyType = certcrypto.RSA2048
client, err := lego.NewClient(legoConfig)
if err != nil {
return nil, fmt.Errorf("lego client: %w", err)
}
a.client = client
// Configure challenge provider.
if err := a.setProvider(); err != nil {
return nil, err
}
// Register account if not already registered.
if acc.Registration == nil {
reg, err := client.Registration.Register(registration.RegisterOptions{
TermsOfServiceAgreed: true,
})
if err != nil {
return nil, fmt.Errorf("acme register: %w", err)
}
a.account.Registration = reg
if err := a.saveAccount(); err != nil {
return nil, fmt.Errorf("save acme account: %w", err)
}
}
return a, nil
}
// setProvider configures the ACME challenge based on cfg.Mode and cfg.DNSProvider.
func (a *ACME) setProvider() error {
switch a.cfg.Mode {
case "http01":
// Lego manages an ephemeral HTTP server on port 80.
return a.client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", "80"))
case "dns01":
provider, err := a.buildDNSProvider()
if err != nil {
return fmt.Errorf("dns01 provider %q: %w", a.cfg.DNSProvider, err)
}
return a.client.Challenge.SetDNS01Provider(provider)
default:
return fmt.Errorf("unknown ACME mode %q (want dns01 or http01)", a.cfg.Mode)
}
}
// buildDNSProvider returns the lego DNS provider for cfg.DNSProvider.
// Credentials are read from env vars set by config.exportProviderEnv().
func (a *ACME) buildDNSProvider() (challenge.Provider, error) {
switch a.cfg.DNSProvider {
case "cloudflare":
return cloudflare.NewDNSProvider()
case "route53":
return route53.NewDNSProvider()
case "digitalocean":
return digitalocean.NewDNSProvider()
case "hetzner":
return hetzner.NewDNSProvider()
default:
// Generic: lego supports 90+ providers; if the user sets the correct
// env vars and the provider name matches a lego provider, it WILL work
// through the lego plugin system. For unlisted providers, document that
// the user must set env vars matching the lego provider docs.
return nil, fmt.Errorf(
"provider %q not built-in; set lego env vars and use a supported provider name.\n"+
"Built-in: cloudflare, route53, digitalocean, hetzner.\n"+
"Full list: https://go-acme.github.io/lego/dns/",
a.cfg.DNSProvider,
)
}
}
// ObtainOrRenew returns a valid *tls.Certificate, obtaining or renewing as needed.
// Uses cached cert on disk if it has >30 days remaining.
func (a *ACME) ObtainOrRenew() (*tls.Certificate, error) {
cached, err := a.loadCachedCert()
if err == nil && cached != nil {
return cached, nil
}
return a.obtain()
}
// obtain requests a new certificate from Let's Encrypt.
func (a *ACME) obtain() (*tls.Certificate, error) {
req := certificate.ObtainRequest{
Domains: a.cfg.Domains,
Bundle: true, // include full chain in cert PEM
}
res, err := a.client.Certificate.Obtain(req)
if err != nil {
return nil, fmt.Errorf("acme obtain %v: %w", a.cfg.Domains, err)
}
if err := a.saveCert(res); err != nil {
// Non-fatal: cert is in memory, just can't persist.
fmt.Printf("[acme] WARNING: could not cache cert: %v\n", err)
}
return parseCert(res.Certificate, res.PrivateKey)
}
// RenewalLoop blocks forever, renewing the cert 30 days before expiry.
// manager.UpdateCert() is called on each renewal to hot-swap the TLS config.
// Call as a goroutine. stopCh receives when it should quit.
func (a *ACME) RenewalLoop(manager *Manager, stopCh <-chan struct{}) {
ticker := time.NewTicker(12 * time.Hour)
defer ticker.Stop()
for {
select {
case <-stopCh:
return
case <-ticker.C:
cached, err := a.loadCachedCert()
if err != nil || cached == nil {
// No cert or corrupt cache — re-obtain.
cert, err := a.obtain()
if err != nil {
fmt.Printf("[acme] renewal obtain error: %v\n", err)
continue
}
manager.UpdateCert(cert)
fmt.Printf("[acme] cert renewed for %v\n", a.cfg.Domains)
}
// loadCachedCert returns nil if cert expires in < 30 days → triggers re-obtain above.
}
}
}
// ---- Cert persistence ----
func (a *ACME) certPath() string { return filepath.Join(a.cacheDir, "cert.pem") }
func (a *ACME) keyPath() string { return filepath.Join(a.cacheDir, "key.pem") }
func (a *ACME) accPath() string { return filepath.Join(a.cacheDir, "account.json") }
func (a *ACME) accKeyPath() string { return filepath.Join(a.cacheDir, "account.key") }
func (a *ACME) saveCert(res *certificate.Resource) error {
if err := os.WriteFile(a.certPath(), res.Certificate, 0600); err != nil {
return err
}
return os.WriteFile(a.keyPath(), res.PrivateKey, 0600)
}
// loadCachedCert returns the cached cert if it has >30 days remaining, nil otherwise.
func (a *ACME) loadCachedCert() (*tls.Certificate, error) {
certPEM, err := os.ReadFile(a.certPath())
if err != nil {
return nil, err
}
keyPEM, err := os.ReadFile(a.keyPath())
if err != nil {
return nil, err
}
cert, err := parseCert(certPEM, keyPEM)
if err != nil {
return nil, err
}
// Check expiry — renew if <30 days left.
if cert.Leaf != nil && time.Until(cert.Leaf.NotAfter) < 30*24*time.Hour {
return nil, nil // signal: needs renewal
}
// Parse leaf if not already parsed.
if cert.Leaf == nil {
leaf, err := x509.ParseCertificate(cert.Certificate[0])
if err == nil && time.Until(leaf.NotAfter) < 30*24*time.Hour {
return nil, nil
}
}
return cert, nil
}
func parseCert(certPEM, keyPEM []byte) (*tls.Certificate, error) {
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, fmt.Errorf("parse cert: %w", err)
}
// Pre-parse leaf for expiry checks.
if len(cert.Certificate) > 0 {
leaf, err := x509.ParseCertificate(cert.Certificate[0])
if err == nil {
cert.Leaf = leaf
}
}
return &cert, nil
}
// ---- Account persistence ----
func (a *ACME) loadOrCreateAccount() (*acmeAccount, error) {
// Try loading existing account.
accData, errAcc := os.ReadFile(a.accPath())
keyData, errKey := os.ReadFile(a.accKeyPath())
if errAcc == nil && errKey == nil {
acc := &acmeAccount{}
if err := json.Unmarshal(accData, acc); err != nil {
return nil, fmt.Errorf("parse account: %w", err)
}
key, err := parseECKey(keyData)
if err != nil {
return nil, fmt.Errorf("parse account key: %w", err)
}
acc.key = key
acc.keyPEM = keyData
return acc, nil
}
// Generate new account key.
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("gen account key: %w", err)
}
keyPEM, err := encodeECKey(key)
if err != nil {
return nil, err
}
acc := &acmeAccount{
Email: a.cfg.Email,
key: key,
keyPEM: keyPEM,
}
return acc, nil
}
func (a *ACME) saveAccount() error {
data, err := json.MarshalIndent(a.account, "", " ")
if err != nil {
return err
}
if err := os.WriteFile(a.accPath(), data, 0600); err != nil {
return err
}
return os.WriteFile(a.accKeyPath(), a.account.keyPEM, 0600)
}
// ---- EC key helpers ----
func encodeECKey(key *ecdsa.PrivateKey) ([]byte, error) {
der, err := x509.MarshalECPrivateKey(key)
if err != nil {
return nil, fmt.Errorf("marshal ec key: %w", err)
}
return pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der}), nil
}
func parseECKey(pemData []byte) (*ecdsa.PrivateKey, error) {
block, _ := pem.Decode(pemData)
if block == nil {
return nil, fmt.Errorf("no PEM block in account key")
}
return x509.ParseECPrivateKey(block.Bytes)
}
+158
View File
@@ -0,0 +1,158 @@
// Package tls provides TLS configuration builders, cert loading,
// and hot-reload on SIGHUP.
package tls
import (
"crypto/tls"
"fmt"
"net"
"os"
"sync"
"sync/atomic"
"unsafe"
)
// MinVersion enforced across all listeners.
const MinVersion = tls.VersionTLS12
// cipherSuites is the approved list — ECDHE + AEAD only.
var cipherSuites = []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
}
// Manager holds the current TLS certificate and rebuilds tls.Config on demand.
// Thread-safe: cert updates are atomic.
type Manager struct {
mu sync.Mutex
certFile string
keyFile string
cert unsafe.Pointer // *tls.Certificate, accessed atomically
}
// NewManager creates a Manager that loads cert + key from disk.
// Call Reload() after obtaining/renewing a cert.
func NewManager(certFile, keyFile string) (*Manager, error) {
m := &Manager{certFile: certFile, keyFile: keyFile}
if certFile != "" && keyFile != "" {
if err := m.Reload(); err != nil {
return nil, err
}
}
return m, nil
}
// Reload re-reads the cert and key from disk atomically.
// Safe to call from a SIGHUP handler or ACME renewal goroutine.
func (m *Manager) Reload() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.certFile == "" || m.keyFile == "" {
return nil
}
cert, err := tls.LoadX509KeyPair(m.certFile, m.keyFile)
if err != nil {
return fmt.Errorf("tls load cert %s / %s: %w", m.certFile, m.keyFile, err)
}
atomic.StorePointer(&m.cert, unsafe.Pointer(&cert))
return nil
}
// UpdateCert replaces the in-memory certificate (used by ACME after renewal).
func (m *Manager) UpdateCert(cert *tls.Certificate) {
atomic.StorePointer(&m.cert, unsafe.Pointer(cert))
}
// GetCertificate implements tls.Config.GetCertificate — returns the current cert
// regardless of SNI (single-domain or wildcard usage).
func (m *Manager) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
p := atomic.LoadPointer(&m.cert)
if p == nil {
return nil, fmt.Errorf("no TLS certificate loaded")
}
return (*tls.Certificate)(p), nil
}
// Config returns a *tls.Config with secure defaults and GetCertificate set.
func (m *Manager) Config() *tls.Config {
cfg := &tls.Config{
GetCertificate: m.GetCertificate,
MinVersion: MinVersion,
CipherSuites: cipherSuites,
// TLS 1.3 cipher suites are configured by the Go runtime automatically.
}
return cfg
}
// ConfigForSMTP returns a tls.Config for SMTP servers.
// SMTP requires setting ServerName in ClientHello so SNI works;
// the GetCertificate callback handles the lookup.
func (m *Manager) ConfigForSMTP() *tls.Config {
cfg := m.Config()
// SMTP clients often don't send SNI, so always return our cert.
return cfg
}
// ---- Manual cert loader (TLS_MODE=manual without hot-reload) ----
// LoadCert loads a certificate pair from disk and returns a ready tls.Config.
// Used for services where the Manager is not needed (e.g. one-off TLS dial).
func LoadCert(certFile, keyFile string) (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, fmt.Errorf("load cert %s: %w", certFile, err)
}
return &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: MinVersion,
CipherSuites: cipherSuites,
}, nil
}
// ---- TLS_MODE=off placeholder ----
// IsOff returns true if the given mode string means no TLS.
func IsOff(mode string) bool { return mode == "off" || mode == "" }
// ---- Per-service cert resolution ----
// ServiceCert describes the cert files for one service (SMTP, IMAP, web).
type ServiceCert struct {
CertFile string
KeyFile string
}
// Resolve returns the cert pair to use for a service:
// service-specific override → global fallback → empty (ACME handles it).
func Resolve(svcCert, svcKey, globalCert, globalKey string) ServiceCert {
if svcCert != "" && svcKey != "" {
return ServiceCert{CertFile: svcCert, KeyFile: svcKey}
}
if globalCert != "" && globalKey != "" {
return ServiceCert{CertFile: globalCert, KeyFile: globalKey}
}
return ServiceCert{} // ACME-managed
}
// FileExists returns true if both files exist.
func (sc ServiceCert) FileExists() bool {
if sc.CertFile == "" || sc.KeyFile == "" {
return false
}
_, errC := os.Stat(sc.CertFile)
_, errK := os.Stat(sc.KeyFile)
return errC == nil && errK == nil
}
// ---- TLS listener ----
// NewListener wraps a net.Listener with TLS using the provided config.
func NewListener(ln net.Listener, cfg *tls.Config) net.Listener {
return tls.NewListener(ln, cfg)
}