initial
This commit is contained in:
@@ -0,0 +1,201 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
)
|
||||
|
||||
// BruteGuard tracks login attempts and bans IPs exceeding the threshold.
|
||||
type BruteGuard struct {
|
||||
db *db.DB
|
||||
maxTries int
|
||||
windowMin int
|
||||
banHours int
|
||||
whitelist map[string]struct{} // exempt IPs
|
||||
}
|
||||
|
||||
// NewBruteGuard creates a brute-force guard.
|
||||
func NewBruteGuard(database *db.DB, maxTries, windowMin, banHours int, whitelist []string) *BruteGuard {
|
||||
wl := make(map[string]struct{}, len(whitelist))
|
||||
for _, ip := range whitelist {
|
||||
wl[ip] = struct{}{}
|
||||
}
|
||||
return &BruteGuard{
|
||||
db: database,
|
||||
maxTries: maxTries,
|
||||
windowMin: windowMin,
|
||||
banHours: banHours,
|
||||
whitelist: wl,
|
||||
}
|
||||
}
|
||||
|
||||
// IsBanned returns true if the IP is currently banned.
|
||||
func (g *BruteGuard) IsBanned(ctx context.Context, ip string) (bool, error) {
|
||||
if _, ok := g.whitelist[ip]; ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var expiresAt sql.NullTime
|
||||
err := g.db.SQL().QueryRowContext(ctx,
|
||||
"SELECT expires_at FROM ip_bans WHERE ip = ?", ip).Scan(&expiresAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("check ban: %w", err)
|
||||
}
|
||||
|
||||
// Permanent ban (expires_at IS NULL) or not yet expired.
|
||||
if !expiresAt.Valid || expiresAt.Time.After(time.Now().UTC()) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Expired ban — clean up.
|
||||
_, _ = g.db.SQL().ExecContext(ctx, "DELETE FROM ip_bans WHERE ip = ?", ip)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// RecordAttempt records a login attempt and bans the IP if threshold exceeded.
|
||||
// Returns (banned, error).
|
||||
func (g *BruteGuard) RecordAttempt(ctx context.Context, ip, email string, success bool) (bool, error) {
|
||||
if _, ok := g.whitelist[ip]; ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Insert attempt.
|
||||
_, err := g.db.SQL().ExecContext(ctx,
|
||||
"INSERT INTO login_attempts (ip, user_email, success, created_at) VALUES (?, ?, ?, ?)",
|
||||
ip, email, success, time.Now().UTC())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("record attempt: %w", err)
|
||||
}
|
||||
|
||||
if success {
|
||||
// Successful login resets the counter (remove recent failed attempts).
|
||||
window := time.Now().UTC().Add(-time.Duration(g.windowMin) * time.Minute)
|
||||
_, _ = g.db.SQL().ExecContext(ctx,
|
||||
"DELETE FROM login_attempts WHERE ip = ? AND success = 0 AND created_at >= ?",
|
||||
ip, window)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Count failures in window.
|
||||
window := time.Now().UTC().Add(-time.Duration(g.windowMin) * time.Minute)
|
||||
var count int
|
||||
err = g.db.SQL().QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM login_attempts WHERE ip = ? AND success = 0 AND created_at >= ?",
|
||||
ip, window).Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("count attempts: %w", err)
|
||||
}
|
||||
|
||||
if count < g.maxTries {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Threshold exceeded — ban the IP.
|
||||
var expiresAt *time.Time
|
||||
if g.banHours > 0 {
|
||||
t := time.Now().UTC().Add(time.Duration(g.banHours) * time.Hour)
|
||||
expiresAt = &t
|
||||
}
|
||||
|
||||
_, err = g.db.SQL().ExecContext(ctx, `
|
||||
INSERT INTO ip_bans (ip, reason, banned_at, expires_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
reason = excluded.reason,
|
||||
banned_at = excluded.banned_at,
|
||||
expires_at = excluded.expires_at`,
|
||||
ip,
|
||||
fmt.Sprintf("brute force: %d failed attempts in %d minutes", count, g.windowMin),
|
||||
time.Now().UTC(),
|
||||
expiresAt,
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("ban ip: %w", err)
|
||||
}
|
||||
|
||||
// Log security event.
|
||||
_, _ = g.db.SQL().ExecContext(ctx, `
|
||||
INSERT INTO security_events (type, ip, detail, created_at)
|
||||
VALUES ('brute_ban', ?, ?, ?)`,
|
||||
ip,
|
||||
fmt.Sprintf("%d failed attempts in %d min window, last target: %s", count, g.windowMin, email),
|
||||
time.Now().UTC(),
|
||||
)
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// BanIP manually bans an IP (from admin panel).
|
||||
func (g *BruteGuard) BanIP(ctx context.Context, ip, reason, bannedBy string, hours int) error {
|
||||
var expiresAt *time.Time
|
||||
if hours > 0 {
|
||||
t := time.Now().UTC().Add(time.Duration(hours) * time.Hour)
|
||||
expiresAt = &t
|
||||
}
|
||||
_, err := g.db.SQL().ExecContext(ctx, `
|
||||
INSERT INTO ip_bans (ip, reason, banned_at, expires_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
reason = excluded.reason,
|
||||
banned_at = excluded.banned_at,
|
||||
expires_at = excluded.expires_at`,
|
||||
ip, reason, time.Now().UTC(), expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
// UnbanIP removes a ban.
|
||||
func (g *BruteGuard) UnbanIP(ctx context.Context, ip, releasedBy string) error {
|
||||
_, err := g.db.SQL().ExecContext(ctx,
|
||||
"DELETE FROM ip_bans WHERE ip = ?", ip)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListBans returns active bans ordered by newest first.
|
||||
func (g *BruteGuard) ListBans(ctx context.Context, limit int) ([]banRow, error) {
|
||||
rows, err := g.db.SQL().QueryContext(ctx, `
|
||||
SELECT ip, reason, banned_at, expires_at
|
||||
FROM ip_bans
|
||||
ORDER BY banned_at DESC
|
||||
LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var bans []banRow
|
||||
for rows.Next() {
|
||||
var b banRow
|
||||
var exp sql.NullTime
|
||||
if err := rows.Scan(&b.IP, &b.Reason, &b.BannedAt, &exp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exp.Valid {
|
||||
b.ExpiresAt = &exp.Time
|
||||
}
|
||||
bans = append(bans, b)
|
||||
}
|
||||
return bans, rows.Err()
|
||||
}
|
||||
|
||||
// PurgeOldAttempts removes login attempt records older than windowMin*2.
|
||||
// Call periodically to keep the table small.
|
||||
func (g *BruteGuard) PurgeOldAttempts(ctx context.Context) error {
|
||||
cutoff := time.Now().UTC().Add(-2 * time.Duration(g.windowMin) * time.Minute)
|
||||
_, err := g.db.SQL().ExecContext(ctx,
|
||||
"DELETE FROM login_attempts WHERE created_at < ?", cutoff)
|
||||
return err
|
||||
}
|
||||
|
||||
type banRow struct {
|
||||
IP string
|
||||
Reason string
|
||||
BannedAt time.Time
|
||||
ExpiresAt *time.Time
|
||||
}
|
||||
@@ -0,0 +1,276 @@
|
||||
// Package auth provides session management and brute-force protection.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
const (
|
||||
sessionCookieName = "mgs_session"
|
||||
sessionPurpose = "session"
|
||||
)
|
||||
|
||||
// SessionStore manages user sessions stored in the database.
|
||||
type SessionStore struct {
|
||||
db *db.DB
|
||||
maxAge int // seconds
|
||||
secureCookie bool
|
||||
}
|
||||
|
||||
// NewSessionStore creates a session store.
|
||||
func NewSessionStore(database *db.DB, maxAge int, secureCookie bool) *SessionStore {
|
||||
return &SessionStore{
|
||||
db: database,
|
||||
maxAge: maxAge,
|
||||
secureCookie: secureCookie,
|
||||
}
|
||||
}
|
||||
|
||||
// Create creates a new session for userID, sets the cookie, and returns the session.
|
||||
func (s *SessionStore) Create(w http.ResponseWriter, r *http.Request, userID int64) (*models.Session, error) {
|
||||
raw, hash, err := crypto.NewToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session token: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
expires := now.Add(time.Duration(s.maxAge) * time.Second)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err = s.db.SQL().ExecContext(ctx, `
|
||||
INSERT INTO sessions (user_id, token_hash, ip, user_agent, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
userID,
|
||||
hash,
|
||||
realIP(r),
|
||||
truncate(r.UserAgent(), 512),
|
||||
now,
|
||||
expires,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert session: %w", err)
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: raw,
|
||||
Path: "/",
|
||||
MaxAge: s.maxAge,
|
||||
HttpOnly: true,
|
||||
Secure: s.secureCookie,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
return &models.Session{
|
||||
UserID: userID,
|
||||
TokenHash: hash,
|
||||
IP: realIP(r),
|
||||
CreatedAt: now,
|
||||
ExpiresAt: expires,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get validates the session cookie and returns the session + user.
|
||||
// Returns (nil, nil, nil) if no valid session exists.
|
||||
func (s *SessionStore) Get(r *http.Request) (*models.Session, *models.User, error) {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
return nil, nil, nil // no cookie = not logged in
|
||||
}
|
||||
|
||||
raw := cookie.Value
|
||||
if len(raw) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
hash := crypto.HashToken(raw)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
row := s.db.SQL().QueryRowContext(ctx, `
|
||||
SELECT s.id, s.user_id, s.token_hash, s.ip, s.user_agent, s.created_at, s.expires_at,
|
||||
u.id, u.domain_id, u.username, u.email, u.display_name,
|
||||
u.quota_bytes, u.used_bytes, u.enabled, u.admin, u.domain_admin,
|
||||
u.mfa_enabled, u.created_at, u.last_login
|
||||
FROM sessions s
|
||||
JOIN users u ON u.id = s.user_id
|
||||
WHERE s.token_hash = ?
|
||||
AND s.expires_at > ?
|
||||
AND u.enabled = 1`,
|
||||
hash, time.Now().UTC())
|
||||
|
||||
var sess models.Session
|
||||
var user models.User
|
||||
var lastLogin sql.NullTime
|
||||
var createdAt sql.NullTime
|
||||
|
||||
err = row.Scan(
|
||||
&sess.ID, &sess.UserID, &sess.TokenHash, &sess.IP, &sess.UserAgent,
|
||||
&sess.CreatedAt, &sess.ExpiresAt,
|
||||
&user.ID, &user.DomainID, &user.Username, &user.Email, &user.DisplayName,
|
||||
&user.QuotaBytes, &user.UsedBytes, &user.Enabled, &user.Admin, &user.DomainAdmin,
|
||||
&user.MFAEnabled, &createdAt, &lastLogin,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("session lookup: %w", err)
|
||||
}
|
||||
if createdAt.Valid {
|
||||
user.CreatedAt = createdAt.Time
|
||||
}
|
||||
if lastLogin.Valid {
|
||||
user.LastLogin = lastLogin.Time
|
||||
}
|
||||
|
||||
return &sess, &user, nil
|
||||
}
|
||||
|
||||
// Destroy deletes the session from the database and clears the cookie.
|
||||
func (s *SessionStore) Destroy(w http.ResponseWriter, r *http.Request) error {
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err != nil {
|
||||
return nil // no session to destroy
|
||||
}
|
||||
|
||||
hash := crypto.HashToken(cookie.Value)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err = s.db.SQL().ExecContext(ctx,
|
||||
"DELETE FROM sessions WHERE token_hash = ?", hash)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete session: %w", err)
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: s.secureCookie,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// DestroyAll deletes all sessions for a user (logout everywhere).
|
||||
func (s *SessionStore) DestroyAll(ctx context.Context, userID int64) error {
|
||||
_, err := s.db.SQL().ExecContext(ctx,
|
||||
"DELETE FROM sessions WHERE user_id = ?", userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// PurgeExpired deletes all expired sessions. Call periodically (e.g. hourly).
|
||||
func (s *SessionStore) PurgeExpired(ctx context.Context) error {
|
||||
_, err := s.db.SQL().ExecContext(ctx,
|
||||
"DELETE FROM sessions WHERE expires_at <= ?", time.Now().UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateLastLogin updates the user's last_login timestamp.
|
||||
func UpdateLastLogin(ctx context.Context, database *db.DB, userID int64) {
|
||||
database.SQL().ExecContext(ctx, // nolint: errcheck — best effort
|
||||
"UPDATE users SET last_login = ? WHERE id = ?",
|
||||
time.Now().UTC(), userID)
|
||||
}
|
||||
|
||||
// ---- User lookup ----
|
||||
|
||||
// GetUserByEmail returns the user matching email (case-insensitive), or nil.
|
||||
func GetUserByEmail(ctx context.Context, database *db.DB, email string) (*models.User, error) {
|
||||
row := database.SQL().QueryRowContext(ctx, `
|
||||
SELECT id, domain_id, username, email, password_hash, display_name,
|
||||
quota_bytes, used_bytes, enabled, admin, domain_admin,
|
||||
mfa_secret_enc, mfa_enabled, recovery_codes_enc,
|
||||
created_at, last_login
|
||||
FROM users
|
||||
WHERE lower(email) = lower(?)`, email)
|
||||
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
// GetUserByID returns the user by ID, or nil.
|
||||
func GetUserByID(ctx context.Context, database *db.DB, id int64) (*models.User, error) {
|
||||
row := database.SQL().QueryRowContext(ctx, `
|
||||
SELECT id, domain_id, username, email, password_hash, display_name,
|
||||
quota_bytes, used_bytes, enabled, admin, domain_admin,
|
||||
mfa_secret_enc, mfa_enabled, recovery_codes_enc,
|
||||
created_at, last_login
|
||||
FROM users
|
||||
WHERE id = ?`, id)
|
||||
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
func scanUser(row *sql.Row) (*models.User, error) {
|
||||
var u models.User
|
||||
var mfaSecretEnc, recoveryCodesEnc []byte
|
||||
var lastLogin sql.NullTime
|
||||
|
||||
err := row.Scan(
|
||||
&u.ID, &u.DomainID, &u.Username, &u.Email, &u.PasswordHash,
|
||||
&u.DisplayName, &u.QuotaBytes, &u.UsedBytes, &u.Enabled,
|
||||
&u.Admin, &u.DomainAdmin,
|
||||
&mfaSecretEnc, &u.MFAEnabled, &recoveryCodesEnc,
|
||||
&u.CreatedAt, &lastLogin,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan user: %w", err)
|
||||
}
|
||||
u.MFASecretEnc = mfaSecretEnc
|
||||
u.RecoveryCodesEnc = recoveryCodesEnc
|
||||
if lastLogin.Valid {
|
||||
u.LastLogin = lastLogin.Time
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// ---- Helpers ----
|
||||
|
||||
func realIP(r *http.Request) string {
|
||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
host, _, err := parseHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func parseHostPort(addr string) (string, string, error) {
|
||||
// net.SplitHostPort returns error for bare IPs without port.
|
||||
for i := len(addr) - 1; i >= 0; i-- {
|
||||
if addr[i] == ':' {
|
||||
return addr[:i], addr[i+1:], nil
|
||||
}
|
||||
}
|
||||
return addr, "", fmt.Errorf("no port")
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max]
|
||||
}
|
||||
@@ -0,0 +1,773 @@
|
||||
// Package config loads and auto-generates app_config.conf.
|
||||
// INI-style: KEY = value, # comments, blank lines ignored.
|
||||
// Missing keys appended on each startup — existing values preserved.
|
||||
// Env var MAILGO_<KEY> overrides file value.
|
||||
package config
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const ConfigPath = "./app_config.conf"
|
||||
|
||||
// Config holds all runtime configuration.
|
||||
type Config struct {
|
||||
// Identity
|
||||
Hostname string // FQDN for SMTP HELO, TLS SNI, URL building
|
||||
DefaultDomain string // Primary mail domain
|
||||
|
||||
// Network — SMTP
|
||||
SMTPIface string
|
||||
SMTPPort int
|
||||
SubmitIface string
|
||||
SubmitPort int // 587 STARTTLS
|
||||
SMTPSPort int // 465 implicit TLS
|
||||
SMTPEnabled bool
|
||||
SubmitEnabled bool
|
||||
SMTPSEnabled bool
|
||||
|
||||
// Network — IMAP
|
||||
IMAPIface string
|
||||
IMAPPort int // 143 STARTTLS
|
||||
IMAPSPort int // 993 implicit TLS
|
||||
IMAPEnabled bool
|
||||
IMAPSEnabled bool
|
||||
|
||||
// Network — Web
|
||||
WebClientIface string
|
||||
WebClientPort int
|
||||
WebAdminIface string
|
||||
WebAdminPort int
|
||||
CalDAVIface string
|
||||
CalDAVPort int
|
||||
|
||||
// TLS
|
||||
TLSMode string // dns01 | http01 | manual | off
|
||||
TLSCert string // manual: path to cert.pem
|
||||
TLSKey string // manual: path to key.pem
|
||||
// Per-service overrides (empty = use global TLS)
|
||||
SMTPTLSCert string
|
||||
SMTPTLSKey string
|
||||
IMAPTLSCert string
|
||||
IMAPTLSKey string
|
||||
WebTLSCert string
|
||||
WebTLSKey string
|
||||
|
||||
// ACME (dns01 / http01)
|
||||
ACMEEmail string
|
||||
ACMECacheDir string
|
||||
ACMEStaging bool
|
||||
ACMEDomains []string // domains to certify; empty = just Hostname
|
||||
|
||||
// DNS-01 provider
|
||||
ACMEDNSProvider string // cloudflare | route53 | digitalocean | hetzner | ...
|
||||
// Cloudflare
|
||||
CFDNSAPIToken string
|
||||
CFAPIKey string
|
||||
CFAPIEmail string
|
||||
// Route53
|
||||
AWSRegion string
|
||||
AWSAccessKeyID string
|
||||
AWSSecretAccessKey string
|
||||
AWSHostedZoneID string
|
||||
// DigitalOcean
|
||||
DOAuthToken string
|
||||
// Hetzner
|
||||
HetznerAPIKey string
|
||||
// Generic: any additional lego env vars should be exported before starting.
|
||||
|
||||
// Secrets (auto-generated on first run — BACK UP app_config.conf)
|
||||
EncryptionKey []byte // 32 bytes, AES-256 master key
|
||||
SessionSecret []byte // session cookie signing
|
||||
|
||||
// Database
|
||||
DBDriver string // sqlite | postgres | mysql | mssql
|
||||
DBPath string // SQLite path
|
||||
DBDSN string // PostgreSQL/MySQL/MSSQL DSN
|
||||
|
||||
// Storage
|
||||
StorageBackend string // db | fs
|
||||
StorageFSPath string // base path for fs storage
|
||||
|
||||
// Security
|
||||
MaxMessageSize int64
|
||||
SessionMaxAge int // seconds
|
||||
BruteMaxTries int
|
||||
BruteWindowMin int
|
||||
BruteBanHours int
|
||||
TrustedProxies []net.IPNet
|
||||
SecureCookie bool
|
||||
BruteWhitelist []net.IP
|
||||
|
||||
// SMTP server
|
||||
SMTPHostname string // override for SMTP HELO (defaults to Hostname)
|
||||
MaxRcptPer int // max recipients per message
|
||||
QueueMaxAgeHours int
|
||||
QueueRetryMins []int // backoff schedule
|
||||
DNSPrimary string
|
||||
DNSSecondary string
|
||||
|
||||
// DKIM
|
||||
DKIMSelector string
|
||||
DKIMAlgo string // rsa2048 | ed25519
|
||||
|
||||
// Spam
|
||||
SpamThreshold int
|
||||
SpamDNSBL []string
|
||||
SpamCheckSPF bool
|
||||
SpamCheckDKIM bool
|
||||
SpamCheckDMARC bool
|
||||
|
||||
// OAuth2 (external accounts)
|
||||
GoogleClientID string
|
||||
GoogleClientSecret string
|
||||
MicrosoftClientID string
|
||||
MicrosoftClientSecret string
|
||||
MicrosoftTenantID string
|
||||
|
||||
// Debug
|
||||
Debug bool
|
||||
LogFile string
|
||||
LogLevel string // debug | info | warn | error
|
||||
}
|
||||
|
||||
// field drives both config file generation and value parsing.
|
||||
type field struct {
|
||||
key string
|
||||
defVal string
|
||||
comments []string
|
||||
secret bool // true = never shown in logs
|
||||
}
|
||||
|
||||
var allFields = []field{
|
||||
// --- Identity ---
|
||||
{key: "HOSTNAME", defVal: "mail.example.com", comments: []string{
|
||||
"--- Server Identity ---",
|
||||
"FQDN used for SMTP HELO/EHLO, TLS SNI, and URL building.",
|
||||
"Must resolve in DNS if using TLS_MODE=autocert/dns01/http01.",
|
||||
}},
|
||||
{key: "DEFAULT_DOMAIN", defVal: "example.com", comments: []string{
|
||||
"Primary mail domain served by this instance.",
|
||||
}},
|
||||
|
||||
// --- SMTP ---
|
||||
{key: "SMTP_IFACE", defVal: "0.0.0.0", comments: []string{
|
||||
"--- SMTP Server (Inbound MTA) ---",
|
||||
"Network interface to bind SMTP port 25.",
|
||||
}},
|
||||
{key: "SMTP_PORT", defVal: "25"},
|
||||
{key: "SMTP_ENABLED", defVal: "true"},
|
||||
{key: "SUBMIT_IFACE", defVal: "0.0.0.0", comments: []string{
|
||||
"--- SMTP Submission (Authenticated Send) ---",
|
||||
}},
|
||||
{key: "SUBMIT_PORT", defVal: "587", comments: []string{"STARTTLS mandatory on this port."}},
|
||||
{key: "SUBMIT_ENABLED", defVal: "true"},
|
||||
{key: "SMTPS_PORT", defVal: "465", comments: []string{"Implicit TLS SMTP submission."}},
|
||||
{key: "SMTPS_ENABLED", defVal: "true"},
|
||||
|
||||
// --- IMAP ---
|
||||
{key: "IMAP_IFACE", defVal: "0.0.0.0", comments: []string{"--- IMAP Server ---"}},
|
||||
{key: "IMAP_PORT", defVal: "143"},
|
||||
{key: "IMAP_ENABLED", defVal: "true"},
|
||||
{key: "IMAPS_PORT", defVal: "993"},
|
||||
{key: "IMAPS_ENABLED", defVal: "true"},
|
||||
|
||||
// --- Web ---
|
||||
{key: "WEBCLIENT_IFACE", defVal: "0.0.0.0", comments: []string{"--- Web Client ---"}},
|
||||
{key: "WEBCLIENT_PORT", defVal: "8080"},
|
||||
{key: "WEBADMIN_IFACE", defVal: "127.0.0.1", comments: []string{
|
||||
"--- Web Admin ---",
|
||||
"Default: loopback only. Change to 0.0.0.0 only behind a reverse proxy with auth.",
|
||||
}},
|
||||
{key: "WEBADMIN_PORT", defVal: "8081"},
|
||||
{key: "CALDAV_IFACE", defVal: "0.0.0.0", comments: []string{"--- CalDAV / CardDAV ---"}},
|
||||
{key: "CALDAV_PORT", defVal: "5232"},
|
||||
|
||||
// --- TLS ---
|
||||
{key: "TLS_MODE", defVal: "dns01", comments: []string{
|
||||
"--- TLS Configuration ---",
|
||||
" dns01 = Let's Encrypt via DNS-01 (no open ports, wildcard support) [RECOMMENDED]",
|
||||
" http01 = Let's Encrypt via HTTP-01 (port 80 must be reachable, no wildcards)",
|
||||
" manual = Provide TLS_CERT + TLS_KEY paths",
|
||||
" off = No TLS (use ONLY behind a TLS-terminating reverse proxy)",
|
||||
}},
|
||||
{key: "TLS_CERT", defVal: "./certs/cert.pem", comments: []string{
|
||||
"Path to certificate file (PEM). Used when TLS_MODE=manual.",
|
||||
"Also used as per-service fallback if SMTP_TLS_CERT etc. are not set.",
|
||||
}},
|
||||
{key: "TLS_KEY", defVal: "./certs/key.pem"},
|
||||
{key: "SMTP_TLS_CERT", defVal: "", comments: []string{"Override TLS cert/key for SMTP services (blank = use global TLS)."}},
|
||||
{key: "SMTP_TLS_KEY", defVal: ""},
|
||||
{key: "IMAP_TLS_CERT", defVal: "", comments: []string{"Override TLS cert/key for IMAP services."}},
|
||||
{key: "IMAP_TLS_KEY", defVal: ""},
|
||||
{key: "WEB_TLS_CERT", defVal: "", comments: []string{"Override TLS cert/key for web/CalDAV ports."}},
|
||||
{key: "WEB_TLS_KEY", defVal: ""},
|
||||
|
||||
// --- ACME ---
|
||||
{key: "ACME_EMAIL", defVal: "", comments: []string{
|
||||
"--- ACME / Let's Encrypt ---",
|
||||
"Email for Let's Encrypt account registration and renewal notices. Required.",
|
||||
}},
|
||||
{key: "ACME_CACHE_DIR", defVal: "./acme-cache", comments: []string{
|
||||
"Directory to cache ACME account data and certificates.",
|
||||
}},
|
||||
{key: "ACME_STAGING", defVal: "false", comments: []string{
|
||||
"Use Let's Encrypt staging server (rate-limit-free testing). Set false for production.",
|
||||
}},
|
||||
{key: "ACME_DOMAINS", defVal: "", comments: []string{
|
||||
"Comma-separated domains to include in the certificate.",
|
||||
"Example: example.com,*.example.com,mail.example.com",
|
||||
"Blank = use HOSTNAME only. Include wildcard for full coverage.",
|
||||
}},
|
||||
{key: "ACME_DNS_PROVIDER", defVal: "cloudflare", comments: []string{
|
||||
"DNS provider for DNS-01 challenge (TLS_MODE=dns01).",
|
||||
"Supported: cloudflare | route53 | digitalocean | hetzner | ovh | porkbun |",
|
||||
" namecheap | gandi | desec | acmedns | godaddy | ... (90+ providers)",
|
||||
"Full list: https://go-acme.github.io/lego/dns/",
|
||||
}},
|
||||
|
||||
// --- Cloudflare ---
|
||||
{key: "CF_DNS_API_TOKEN", defVal: "", secret: true, comments: []string{
|
||||
"--- Cloudflare DNS-01 (ACME_DNS_PROVIDER=cloudflare) ---",
|
||||
"API Token with Zone.DNS:Edit permission on target zone(s).",
|
||||
"Preferred over CF_API_KEY. Create at: https://dash.cloudflare.com/profile/api-tokens",
|
||||
}},
|
||||
{key: "CF_API_KEY", defVal: "", secret: true, comments: []string{"Global API Key (alternative to CF_DNS_API_TOKEN)."}},
|
||||
{key: "CF_API_EMAIL", defVal: "", comments: []string{"Account email (required with CF_API_KEY, not needed with CF_DNS_API_TOKEN)."}},
|
||||
|
||||
// --- Route53 ---
|
||||
{key: "AWS_REGION", defVal: "", comments: []string{"--- AWS Route53 (ACME_DNS_PROVIDER=route53) ---"}},
|
||||
{key: "AWS_ACCESS_KEY_ID", defVal: "", secret: true},
|
||||
{key: "AWS_SECRET_ACCESS_KEY", defVal: "", secret: true},
|
||||
{key: "AWS_HOSTED_ZONE_ID", defVal: "", comments: []string{"Optional: skip auto-detection."}},
|
||||
|
||||
// --- DigitalOcean ---
|
||||
{key: "DO_AUTH_TOKEN", defVal: "", secret: true, comments: []string{"--- DigitalOcean (ACME_DNS_PROVIDER=digitalocean) ---"}},
|
||||
|
||||
// --- Hetzner ---
|
||||
{key: "HETZNER_API_KEY", defVal: "", secret: true, comments: []string{"--- Hetzner DNS (ACME_DNS_PROVIDER=hetzner) ---"}},
|
||||
|
||||
// --- Secrets ---
|
||||
{key: "ENCRYPTION_KEY", defVal: "", secret: true, comments: []string{
|
||||
"--- Secrets (auto-generated — BACK UP this file!) ---",
|
||||
"AES-256 master key for all data at rest (emails, tokens, keys, contacts, calendar).",
|
||||
"64 hex characters = 32 bytes. Losing this key = permanent data loss.",
|
||||
}},
|
||||
{key: "SESSION_SECRET", defVal: "", secret: true, comments: []string{
|
||||
"Session cookie signing secret. Changing this logs out all users.",
|
||||
}},
|
||||
|
||||
// --- Database ---
|
||||
{key: "DB_DRIVER", defVal: "sqlite", comments: []string{
|
||||
"--- Database ---",
|
||||
"Database driver: sqlite | postgres | mysql | mssql",
|
||||
}},
|
||||
{key: "DB_PATH", defVal: "./data/mail.db", comments: []string{"SQLite database path (DB_DRIVER=sqlite)."}},
|
||||
{key: "DB_DSN", defVal: "", secret: true, comments: []string{
|
||||
"Connection string for PostgreSQL/MySQL/MSSQL.",
|
||||
" PostgreSQL: host=localhost port=5432 user=mail password=secret dbname=mail sslmode=require",
|
||||
" MySQL: mail:secret@tcp(localhost:3306)/mail?tls=true",
|
||||
" MSSQL: sqlserver://mail:secret@localhost?database=mail",
|
||||
}},
|
||||
|
||||
// --- Storage ---
|
||||
{key: "STORAGE_BACKEND", defVal: "db", comments: []string{
|
||||
"--- Email Storage ---",
|
||||
" db = Store encrypted message blobs in database (simple, single backup file)",
|
||||
" fs = Store encrypted files on filesystem, metadata in DB (better for large attachments)",
|
||||
}},
|
||||
{key: "STORAGE_FS_PATH", defVal: "./data/messages", comments: []string{
|
||||
"Base directory for filesystem storage (STORAGE_BACKEND=fs).",
|
||||
}},
|
||||
|
||||
// --- Security ---
|
||||
{key: "MAX_MESSAGE_SIZE", defVal: "52428800", comments: []string{
|
||||
"--- Security ---",
|
||||
"Maximum accepted email size in bytes (default 50 MB).",
|
||||
}},
|
||||
{key: "SESSION_MAX_AGE", defVal: "604800", comments: []string{"Session lifetime in seconds (default 7 days)."}},
|
||||
{key: "BRUTE_MAX_TRIES", defVal: "5"},
|
||||
{key: "BRUTE_WINDOW_MIN", defVal: "30"},
|
||||
{key: "BRUTE_BAN_HOURS", defVal: "24"},
|
||||
{key: "BRUTE_WHITELIST_IPS", defVal: "", comments: []string{"Comma-separated IPs exempt from brute-force banning."}},
|
||||
{key: "TRUSTED_PROXIES", defVal: "", comments: []string{
|
||||
"Comma-separated CIDR ranges of trusted reverse proxies.",
|
||||
"Only these may set X-Forwarded-For / X-Forwarded-Proto headers.",
|
||||
}},
|
||||
{key: "SECURE_COOKIE", defVal: "false", comments: []string{
|
||||
"Mark session cookies Secure. Set true when serving over HTTPS.",
|
||||
"Auto-enabled when BASE_URL starts with https://",
|
||||
}},
|
||||
|
||||
// --- SMTP tuning ---
|
||||
{key: "SMTP_HOSTNAME", defVal: "", comments: []string{
|
||||
"--- SMTP Tuning ---",
|
||||
"SMTP HELO/EHLO hostname override. Blank = use HOSTNAME.",
|
||||
}},
|
||||
{key: "MAX_RCPT_PER", defVal: "100", comments: []string{"Maximum recipients per message."}},
|
||||
{key: "QUEUE_MAX_AGE_HOURS", defVal: "72", comments: []string{"Queue age before bounce (hours)."}},
|
||||
{key: "QUEUE_RETRY_MINS", defVal: "5,15,60,240,480", comments: []string{"Retry backoff schedule (minutes between attempts)."}},
|
||||
{key: "DNS_PRIMARY", defVal: "1.1.1.1"},
|
||||
{key: "DNS_SECONDARY", defVal: "8.8.8.8"},
|
||||
|
||||
// --- DKIM ---
|
||||
{key: "DKIM_SELECTOR", defVal: "mail", comments: []string{
|
||||
"--- DKIM ---",
|
||||
"Default DKIM selector for new domains.",
|
||||
}},
|
||||
{key: "DKIM_ALGO", defVal: "rsa2048", comments: []string{"Key algorithm: rsa2048 | ed25519"}},
|
||||
|
||||
// --- Spam ---
|
||||
{key: "SPAM_THRESHOLD", defVal: "10", comments: []string{
|
||||
"--- Spam Filtering ---",
|
||||
"Messages with spam score >= threshold delivered to Spam folder.",
|
||||
}},
|
||||
{key: "SPAM_DNSBL", defVal: "zen.spamhaus.org,bl.spamcop.net"},
|
||||
{key: "SPAM_CHECK_SPF", defVal: "true"},
|
||||
{key: "SPAM_CHECK_DKIM", defVal: "true"},
|
||||
{key: "SPAM_CHECK_DMARC", defVal: "true"},
|
||||
|
||||
// --- OAuth2 ---
|
||||
{key: "GOOGLE_CLIENT_ID", defVal: "", comments: []string{
|
||||
"--- Google OAuth2 (external Gmail accounts) ---",
|
||||
"Create at: https://console.cloud.google.com/apis/credentials",
|
||||
"Required scope: https://mail.google.com/",
|
||||
}},
|
||||
{key: "GOOGLE_CLIENT_SECRET", defVal: "", secret: true},
|
||||
{key: "MICROSOFT_CLIENT_ID", defVal: "", comments: []string{
|
||||
"--- Microsoft OAuth2 (external Outlook accounts) ---",
|
||||
"Register at: https://portal.azure.com/#blade/Microsoft_AAD_RegisteredApps",
|
||||
}},
|
||||
{key: "MICROSOFT_CLIENT_SECRET", defVal: "", secret: true},
|
||||
{key: "MICROSOFT_TENANT_ID", defVal: "consumers"},
|
||||
|
||||
// --- Debug ---
|
||||
{key: "DEBUG", defVal: "false", comments: []string{"--- Debug ---"}},
|
||||
{key: "LOG_FILE", defVal: "./logs/mail.log"},
|
||||
{key: "LOG_LEVEL", defVal: "info", comments: []string{"Log level: debug | info | warn | error"}},
|
||||
}
|
||||
|
||||
// Load reads app_config.conf, generates it if missing, returns populated Config.
|
||||
func Load() (*Config, error) {
|
||||
if err := os.MkdirAll("./data", 0700); err != nil {
|
||||
return nil, fmt.Errorf("create data dir: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll("./logs", 0700); err != nil {
|
||||
return nil, fmt.Errorf("create logs dir: %w", err)
|
||||
}
|
||||
|
||||
existing, err := readFile(ConfigPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Auto-generate secrets if absent.
|
||||
if existing["ENCRYPTION_KEY"] == "" {
|
||||
existing["ENCRYPTION_KEY"] = mustHex(32)
|
||||
fmt.Fprintln(os.Stderr, "[mailgosend] WARNING: Generated new ENCRYPTION_KEY — back up app_config.conf immediately!")
|
||||
}
|
||||
if existing["SESSION_SECRET"] == "" {
|
||||
existing["SESSION_SECRET"] = mustHex(32)
|
||||
}
|
||||
|
||||
if err := writeFile(ConfigPath, existing); err != nil {
|
||||
return nil, fmt.Errorf("write config: %w", err)
|
||||
}
|
||||
|
||||
get := func(key string) string {
|
||||
if v := os.Getenv("MAILGO_" + key); v != "" {
|
||||
return v
|
||||
}
|
||||
return existing[key]
|
||||
}
|
||||
|
||||
// Decode master encryption key.
|
||||
encHex := get("ENCRYPTION_KEY")
|
||||
encKey, err := hex.DecodeString(encHex)
|
||||
if err != nil || len(encKey) != 32 {
|
||||
return nil, fmt.Errorf("ENCRYPTION_KEY must be 64 hex chars (32 bytes), got %d chars", len(encHex))
|
||||
}
|
||||
sessSecret := get("SESSION_SECRET")
|
||||
if sessSecret == "" {
|
||||
return nil, fmt.Errorf("SESSION_SECRET missing")
|
||||
}
|
||||
|
||||
trustedProxies, err := parseCIDRs(get("TRUSTED_PROXIES"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("TRUSTED_PROXIES: %w", err)
|
||||
}
|
||||
|
||||
acmeDomains := splitTrim(get("ACME_DOMAINS"), ",")
|
||||
retryMins := parseIntList(get("QUEUE_RETRY_MINS"), []int{5, 15, 60, 240, 480})
|
||||
dnsbl := splitTrim(get("SPAM_DNSBL"), ",")
|
||||
|
||||
smtpHostname := get("SMTP_HOSTNAME")
|
||||
if smtpHostname == "" {
|
||||
smtpHostname = get("HOSTNAME")
|
||||
}
|
||||
if smtpHostname == "" {
|
||||
smtpHostname = "localhost"
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
Hostname: orDefault(get("HOSTNAME"), "mail.example.com"),
|
||||
DefaultDomain: orDefault(get("DEFAULT_DOMAIN"), "example.com"),
|
||||
|
||||
SMTPIface: orDefault(get("SMTP_IFACE"), "0.0.0.0"),
|
||||
SMTPPort: atoi(get("SMTP_PORT"), 25),
|
||||
SMTPEnabled: atobool(get("SMTP_ENABLED"), true),
|
||||
SubmitIface: orDefault(get("SUBMIT_IFACE"), "0.0.0.0"),
|
||||
SubmitPort: atoi(get("SUBMIT_PORT"), 587),
|
||||
SubmitEnabled: atobool(get("SUBMIT_ENABLED"), true),
|
||||
SMTPSPort: atoi(get("SMTPS_PORT"), 465),
|
||||
SMTPSEnabled: atobool(get("SMTPS_ENABLED"), true),
|
||||
|
||||
IMAPIface: orDefault(get("IMAP_IFACE"), "0.0.0.0"),
|
||||
IMAPPort: atoi(get("IMAP_PORT"), 143),
|
||||
IMAPEnabled: atobool(get("IMAP_ENABLED"), true),
|
||||
IMAPSPort: atoi(get("IMAPS_PORT"), 993),
|
||||
IMAPSEnabled: atobool(get("IMAPS_ENABLED"), true),
|
||||
|
||||
WebClientIface: orDefault(get("WEBCLIENT_IFACE"), "0.0.0.0"),
|
||||
WebClientPort: atoi(get("WEBCLIENT_PORT"), 8080),
|
||||
WebAdminIface: orDefault(get("WEBADMIN_IFACE"), "127.0.0.1"),
|
||||
WebAdminPort: atoi(get("WEBADMIN_PORT"), 8081),
|
||||
CalDAVIface: orDefault(get("CALDAV_IFACE"), "0.0.0.0"),
|
||||
CalDAVPort: atoi(get("CALDAV_PORT"), 5232),
|
||||
|
||||
TLSMode: orDefault(get("TLS_MODE"), "dns01"),
|
||||
TLSCert: get("TLS_CERT"),
|
||||
TLSKey: get("TLS_KEY"),
|
||||
SMTPTLSCert: get("SMTP_TLS_CERT"),
|
||||
SMTPTLSKey: get("SMTP_TLS_KEY"),
|
||||
IMAPTLSCert: get("IMAP_TLS_CERT"),
|
||||
IMAPTLSKey: get("IMAP_TLS_KEY"),
|
||||
WebTLSCert: get("WEB_TLS_CERT"),
|
||||
WebTLSKey: get("WEB_TLS_KEY"),
|
||||
|
||||
ACMEEmail: get("ACME_EMAIL"),
|
||||
ACMECacheDir: orDefault(get("ACME_CACHE_DIR"), "./acme-cache"),
|
||||
ACMEStaging: atobool(get("ACME_STAGING"), false),
|
||||
ACMEDomains: acmeDomains,
|
||||
ACMEDNSProvider: orDefault(get("ACME_DNS_PROVIDER"), "cloudflare"),
|
||||
|
||||
CFDNSAPIToken: get("CF_DNS_API_TOKEN"),
|
||||
CFAPIKey: get("CF_API_KEY"),
|
||||
CFAPIEmail: get("CF_API_EMAIL"),
|
||||
AWSRegion: get("AWS_REGION"),
|
||||
AWSAccessKeyID: get("AWS_ACCESS_KEY_ID"),
|
||||
AWSSecretAccessKey: get("AWS_SECRET_ACCESS_KEY"),
|
||||
AWSHostedZoneID: get("AWS_HOSTED_ZONE_ID"),
|
||||
DOAuthToken: get("DO_AUTH_TOKEN"),
|
||||
HetznerAPIKey: get("HETZNER_API_KEY"),
|
||||
|
||||
EncryptionKey: encKey,
|
||||
SessionSecret: []byte(sessSecret),
|
||||
|
||||
DBDriver: orDefault(get("DB_DRIVER"), "sqlite"),
|
||||
DBPath: orDefault(get("DB_PATH"), "./data/mail.db"),
|
||||
DBDSN: get("DB_DSN"),
|
||||
|
||||
StorageBackend: orDefault(get("STORAGE_BACKEND"), "db"),
|
||||
StorageFSPath: orDefault(get("STORAGE_FS_PATH"), "./data/messages"),
|
||||
|
||||
MaxMessageSize: int64(atoi(get("MAX_MESSAGE_SIZE"), 52428800)),
|
||||
SessionMaxAge: atoi(get("SESSION_MAX_AGE"), 604800),
|
||||
BruteMaxTries: atoi(get("BRUTE_MAX_TRIES"), 5),
|
||||
BruteWindowMin: atoi(get("BRUTE_WINDOW_MIN"), 30),
|
||||
BruteBanHours: atoi(get("BRUTE_BAN_HOURS"), 24),
|
||||
BruteWhitelist: parseIPs(get("BRUTE_WHITELIST_IPS")),
|
||||
TrustedProxies: trustedProxies,
|
||||
SecureCookie: atobool(get("SECURE_COOKIE"), false),
|
||||
|
||||
SMTPHostname: smtpHostname,
|
||||
MaxRcptPer: atoi(get("MAX_RCPT_PER"), 100),
|
||||
QueueMaxAgeHours: atoi(get("QUEUE_MAX_AGE_HOURS"), 72),
|
||||
QueueRetryMins: retryMins,
|
||||
DNSPrimary: orDefault(get("DNS_PRIMARY"), "1.1.1.1"),
|
||||
DNSSecondary: orDefault(get("DNS_SECONDARY"), "8.8.8.8"),
|
||||
|
||||
DKIMSelector: orDefault(get("DKIM_SELECTOR"), "mail"),
|
||||
DKIMAlgo: orDefault(get("DKIM_ALGO"), "rsa2048"),
|
||||
|
||||
SpamThreshold: atoi(get("SPAM_THRESHOLD"), 10),
|
||||
SpamDNSBL: dnsbl,
|
||||
SpamCheckSPF: atobool(get("SPAM_CHECK_SPF"), true),
|
||||
SpamCheckDKIM: atobool(get("SPAM_CHECK_DKIM"), true),
|
||||
SpamCheckDMARC: atobool(get("SPAM_CHECK_DMARC"), true),
|
||||
|
||||
GoogleClientID: get("GOOGLE_CLIENT_ID"),
|
||||
GoogleClientSecret: get("GOOGLE_CLIENT_SECRET"),
|
||||
MicrosoftClientID: get("MICROSOFT_CLIENT_ID"),
|
||||
MicrosoftClientSecret: get("MICROSOFT_CLIENT_SECRET"),
|
||||
MicrosoftTenantID: orDefault(get("MICROSOFT_TENANT_ID"), "consumers"),
|
||||
|
||||
Debug: atobool(get("DEBUG"), false),
|
||||
LogFile: orDefault(get("LOG_FILE"), "./logs/mail.log"),
|
||||
LogLevel: orDefault(get("LOG_LEVEL"), "info"),
|
||||
}
|
||||
|
||||
// Export provider-specific env vars so lego can pick them up.
|
||||
cfg.exportProviderEnv()
|
||||
|
||||
logStartup(cfg)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// exportProviderEnv sets standard lego env vars from config values.
|
||||
// This allows using app_config.conf as the single source of truth.
|
||||
func (c *Config) exportProviderEnv() {
|
||||
setenv := func(key, val string) {
|
||||
if val != "" && os.Getenv(key) == "" {
|
||||
os.Setenv(key, val) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
setenv("CF_DNS_API_TOKEN", c.CFDNSAPIToken)
|
||||
setenv("CF_API_KEY", c.CFAPIKey)
|
||||
setenv("CF_API_EMAIL", c.CFAPIEmail)
|
||||
setenv("AWS_REGION", c.AWSRegion)
|
||||
setenv("AWS_ACCESS_KEY_ID", c.AWSAccessKeyID)
|
||||
setenv("AWS_SECRET_ACCESS_KEY", c.AWSSecretAccessKey)
|
||||
setenv("AWS_HOSTED_ZONE_ID", c.AWSHostedZoneID)
|
||||
setenv("DO_AUTH_TOKEN", c.DOAuthToken)
|
||||
setenv("HETZNER_API_KEY", c.HetznerAPIKey)
|
||||
}
|
||||
|
||||
// ACMEDomainList returns ACME_DOMAINS, falling back to Hostname.
|
||||
func (c *Config) ACMEDomainList() []string {
|
||||
if len(c.ACMEDomains) > 0 {
|
||||
return c.ACMEDomains
|
||||
}
|
||||
return []string{c.Hostname}
|
||||
}
|
||||
|
||||
// RealIP extracts the client IP, honouring X-Forwarded-For from trusted proxies.
|
||||
func (c *Config) RealIP(remoteAddr, xForwardedFor string) string {
|
||||
remoteIP, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
remoteIP = remoteAddr
|
||||
}
|
||||
if xForwardedFor == "" || !c.isTrustedProxy(remoteIP) {
|
||||
return remoteIP
|
||||
}
|
||||
parts := strings.Split(xForwardedFor, ",")
|
||||
if len(parts) > 0 {
|
||||
if ip := strings.TrimSpace(parts[0]); net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
func (c *Config) isTrustedProxy(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
for _, cidr := range c.TrustedProxies {
|
||||
if cidr.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Config) IsIPWhitelisted(ip string) bool {
|
||||
parsed := net.ParseIP(ip)
|
||||
if parsed == nil {
|
||||
return false
|
||||
}
|
||||
for _, w := range c.BruteWhitelist {
|
||||
if w.Equal(parsed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ---- Config file I/O ----
|
||||
|
||||
func readFile(path string) (map[string]string, error) {
|
||||
vals := make(map[string]string)
|
||||
f, err := os.Open(path)
|
||||
if os.IsNotExist(err) {
|
||||
return vals, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open config: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
sc := bufio.NewScanner(f)
|
||||
for sc.Scan() {
|
||||
line := strings.TrimSpace(sc.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
idx := strings.IndexByte(line, '=')
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
k := strings.TrimSpace(line[:idx])
|
||||
v := strings.TrimSpace(line[idx+1:])
|
||||
vals[k] = v
|
||||
}
|
||||
return vals, sc.Err()
|
||||
}
|
||||
|
||||
func writeFile(path string, existing map[string]string) error {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("# mailgosend Configuration\n")
|
||||
sb.WriteString("# =========================\n")
|
||||
sb.WriteString("# Auto-generated and updated on each startup.\n")
|
||||
sb.WriteString("# Edit freely — your values are always preserved.\n")
|
||||
sb.WriteString("# Override any key via env var: MAILGO_<KEY>=value\n")
|
||||
sb.WriteString("#\n\n")
|
||||
|
||||
for _, f := range allFields {
|
||||
for _, c := range f.comments {
|
||||
if c == "" {
|
||||
sb.WriteString("#\n")
|
||||
} else {
|
||||
sb.WriteString("# " + c + "\n")
|
||||
}
|
||||
}
|
||||
v := existing[f.key]
|
||||
if v == "" {
|
||||
v = f.defVal
|
||||
}
|
||||
sb.WriteString(f.key + " = " + v + "\n\n")
|
||||
}
|
||||
return os.WriteFile(path, []byte(sb.String()), 0600)
|
||||
}
|
||||
|
||||
// ---- Startup log ----
|
||||
|
||||
func logStartup(c *Config) {
|
||||
fmt.Printf("mailgosend starting\n")
|
||||
fmt.Printf(" Hostname : %s\n", c.Hostname)
|
||||
fmt.Printf(" Default domain: %s\n", c.DefaultDomain)
|
||||
fmt.Printf(" TLS mode : %s\n", c.TLSMode)
|
||||
if c.TLSMode == "dns01" || c.TLSMode == "http01" {
|
||||
fmt.Printf(" ACME provider: %s\n", c.ACMEDNSProvider)
|
||||
fmt.Printf(" ACME domains : %v\n", c.ACMEDomainList())
|
||||
}
|
||||
fmt.Printf(" DB driver : %s\n", c.DBDriver)
|
||||
if c.SMTPEnabled {
|
||||
fmt.Printf(" SMTP : %s:%d\n", c.SMTPIface, c.SMTPPort)
|
||||
}
|
||||
if c.SubmitEnabled {
|
||||
fmt.Printf(" Submission : %s:%d (STARTTLS)\n", c.SubmitIface, c.SubmitPort)
|
||||
}
|
||||
if c.SMTPSEnabled {
|
||||
fmt.Printf(" SMTPS : %s:%d (TLS)\n", c.SMTPIface, c.SMTPSPort)
|
||||
}
|
||||
if c.IMAPEnabled {
|
||||
fmt.Printf(" IMAP : %s:%d (STARTTLS)\n", c.IMAPIface, c.IMAPPort)
|
||||
}
|
||||
if c.IMAPSEnabled {
|
||||
fmt.Printf(" IMAPS : %s:%d (TLS)\n", c.IMAPIface, c.IMAPSPort)
|
||||
}
|
||||
fmt.Printf(" Web client : %s:%d\n", c.WebClientIface, c.WebClientPort)
|
||||
fmt.Printf(" Web admin : %s:%d\n", c.WebAdminIface, c.WebAdminPort)
|
||||
fmt.Printf(" CalDAV : %s:%d\n", c.CalDAVIface, c.CalDAVPort)
|
||||
}
|
||||
|
||||
// ---- Helpers ----
|
||||
|
||||
func mustHex(n int) string {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic("crypto/rand unavailable: " + err.Error())
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func atoi(s string, fallback int) int {
|
||||
if v, err := strconv.Atoi(s); err == nil {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func atobool(s string, fallback bool) bool {
|
||||
if v, err := strconv.ParseBool(s); err == nil {
|
||||
return v
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func orDefault(s, def string) string {
|
||||
if s == "" {
|
||||
return def
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func splitTrim(s, sep string) []string {
|
||||
var out []string
|
||||
for _, p := range strings.Split(s, sep) {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseIntList(s string, fallback []int) []int {
|
||||
parts := splitTrim(s, ",")
|
||||
if len(parts) == 0 {
|
||||
return fallback
|
||||
}
|
||||
out := make([]int, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
if v, err := strconv.Atoi(p); err == nil {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return fallback
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseCIDRs(s string) ([]net.IPNet, error) {
|
||||
var nets []net.IPNet
|
||||
for _, raw := range splitTrim(s, ",") {
|
||||
if !strings.Contains(raw, "/") {
|
||||
ip := net.ParseIP(raw)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid IP %q", raw)
|
||||
}
|
||||
bits := 32
|
||||
if ip.To4() == nil {
|
||||
bits = 128
|
||||
}
|
||||
raw = fmt.Sprintf("%s/%d", ip.String(), bits)
|
||||
}
|
||||
_, ipNet, err := net.ParseCIDR(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid CIDR %q: %w", raw, err)
|
||||
}
|
||||
nets = append(nets, *ipNet)
|
||||
}
|
||||
return nets, nil
|
||||
}
|
||||
|
||||
func parseIPs(s string) []net.IP {
|
||||
var ips []net.IP
|
||||
for _, raw := range splitTrim(s, ",") {
|
||||
if ip := net.ParseIP(raw); ip != nil {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
return ips
|
||||
}
|
||||
@@ -0,0 +1,216 @@
|
||||
// Package crypto provides AES-256-GCM encryption, HKDF key derivation,
|
||||
// bcrypt helpers, and secure random utilities.
|
||||
// All encryption uses authenticated encryption — tampering is detected.
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
const (
|
||||
// BcryptCost is the minimum bcrypt work factor.
|
||||
BcryptCost = 12
|
||||
|
||||
// keyLen is AES-256 key length in bytes.
|
||||
keyLen = 32
|
||||
|
||||
// gcmNonceLen is the standard GCM nonce size.
|
||||
gcmNonceLen = 12
|
||||
)
|
||||
|
||||
// Crypto holds the master encryption key.
|
||||
// One instance per application, injected everywhere that needs encryption.
|
||||
type Crypto struct {
|
||||
masterKey [keyLen]byte
|
||||
}
|
||||
|
||||
// New creates a Crypto instance from a 32-byte master key.
|
||||
func New(masterKey []byte) (*Crypto, error) {
|
||||
if len(masterKey) != keyLen {
|
||||
return nil, fmt.Errorf("crypto: master key must be %d bytes, got %d", keyLen, len(masterKey))
|
||||
}
|
||||
c := &Crypto{}
|
||||
copy(c.masterKey[:], masterKey)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// DeriveKey returns a unique 32-byte AES-256 key for a given purpose + userID.
|
||||
// Uses HKDF-SHA256 so each (purpose, userID) pair gets a unique subkey,
|
||||
// and the master key is never used directly for encryption.
|
||||
func (c *Crypto) DeriveKey(purpose string, userID int64) ([keyLen]byte, error) {
|
||||
var key [keyLen]byte
|
||||
info := fmt.Sprintf("%s:user:%d", purpose, userID)
|
||||
r := hkdf.New(sha256.New, c.masterKey[:], nil, []byte(info))
|
||||
if _, err := io.ReadFull(r, key[:]); err != nil {
|
||||
return key, fmt.Errorf("hkdf derive: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// DeriveKeyGlobal returns a 32-byte key derived from master key for global use
|
||||
// (e.g. encrypting DKIM private keys stored per-domain, not per-user).
|
||||
func (c *Crypto) DeriveKeyGlobal(purpose string) ([keyLen]byte, error) {
|
||||
var key [keyLen]byte
|
||||
r := hkdf.New(sha256.New, c.masterKey[:], nil, []byte("global:"+purpose))
|
||||
if _, err := io.ReadFull(r, key[:]); err != nil {
|
||||
return key, fmt.Errorf("hkdf derive global: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext with AES-256-GCM using the provided 32-byte key.
|
||||
// Returns nonce||ciphertext||tag (nonce prepended, all opaque bytes).
|
||||
// Returns an error if plaintext is nil (use []byte{} for empty).
|
||||
func Encrypt(key [keyLen]byte, plaintext []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("aes new cipher: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("aes gcm: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcmNonceLen)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("rand nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts a nonce||ciphertext||tag blob produced by Encrypt.
|
||||
func Decrypt(key [keyLen]byte, ciphertext []byte) ([]byte, error) {
|
||||
if len(ciphertext) < gcmNonceLen {
|
||||
return nil, fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
block, err := aes.NewCipher(key[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("aes new cipher: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("aes gcm: %w", err)
|
||||
}
|
||||
|
||||
nonce := ciphertext[:gcmNonceLen]
|
||||
data := ciphertext[gcmNonceLen:]
|
||||
plaintext, err := gcm.Open(nil, nonce, data, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt: %w", err) // do not leak GCM error details
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptForUser derives a per-user key and encrypts.
|
||||
func (c *Crypto) EncryptForUser(userID int64, purpose string, plaintext []byte) ([]byte, error) {
|
||||
key, err := c.DeriveKey(purpose, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Encrypt(key, plaintext)
|
||||
}
|
||||
|
||||
// DecryptForUser derives a per-user key and decrypts.
|
||||
func (c *Crypto) DecryptForUser(userID int64, purpose string, ciphertext []byte) ([]byte, error) {
|
||||
key, err := c.DeriveKey(purpose, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Decrypt(key, ciphertext)
|
||||
}
|
||||
|
||||
// EncryptGlobal derives a global key for given purpose and encrypts.
|
||||
func (c *Crypto) EncryptGlobal(purpose string, plaintext []byte) ([]byte, error) {
|
||||
key, err := c.DeriveKeyGlobal(purpose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Encrypt(key, plaintext)
|
||||
}
|
||||
|
||||
// DecryptGlobal derives a global key for given purpose and decrypts.
|
||||
func (c *Crypto) DecryptGlobal(purpose string, ciphertext []byte) ([]byte, error) {
|
||||
key, err := c.DeriveKeyGlobal(purpose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Decrypt(key, ciphertext)
|
||||
}
|
||||
|
||||
// ---- Bcrypt ----
|
||||
|
||||
// HashPassword hashes a password with bcrypt at cost BcryptCost.
|
||||
func HashPassword(password string) (string, error) {
|
||||
if password == "" {
|
||||
return "", fmt.Errorf("password must not be empty")
|
||||
}
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), BcryptCost)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("bcrypt: %w", err)
|
||||
}
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
// CheckPassword returns nil if password matches the stored bcrypt hash.
|
||||
// Uses constant-time comparison internally (bcrypt).
|
||||
func CheckPassword(hash, password string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
}
|
||||
|
||||
// ---- Session tokens ----
|
||||
|
||||
// NewToken generates a cryptographically random 32-byte token and returns
|
||||
// (rawToken, sha256HexHash). Store the hash; send the raw token to the client.
|
||||
func NewToken() (raw string, hash string, err error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err = rand.Read(b); err != nil {
|
||||
return "", "", fmt.Errorf("rand token: %w", err)
|
||||
}
|
||||
raw = hex.EncodeToString(b)
|
||||
hash = HashToken(raw)
|
||||
return raw, hash, nil
|
||||
}
|
||||
|
||||
// HashToken returns the SHA-256 hex hash of a raw token string.
|
||||
func HashToken(raw string) string {
|
||||
h := sha256.Sum256([]byte(raw))
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// SecureCompare returns true if a == b using constant-time comparison.
|
||||
// Use for any comparison where timing attacks are a concern.
|
||||
func SecureCompare(a, b string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
|
||||
}
|
||||
|
||||
// ---- Random helpers ----
|
||||
|
||||
// RandomHex returns n random bytes as a hex string (length 2n).
|
||||
func RandomHex(n int) (string, error) {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("rand: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// RandomBytes returns n cryptographically random bytes.
|
||||
func RandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return nil, fmt.Errorf("rand: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
// Package db provides the database/sql wrapper and driver registration.
|
||||
// Default driver: modernc.org/sqlite (pure Go, no CGO).
|
||||
// Additional drivers registered via build tags: postgres, mysql, mssql.
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite" // pure-Go SQLite driver
|
||||
)
|
||||
|
||||
// DB wraps sql.DB with convenience methods and prepared statement caching.
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
driver string
|
||||
}
|
||||
|
||||
// Open opens and validates the database connection, runs migrations, returns DB.
|
||||
func Open(driver, dsn string) (*DB, error) {
|
||||
if driver == "" {
|
||||
driver = "sqlite"
|
||||
}
|
||||
|
||||
// Map friendly driver names to database/sql driver names.
|
||||
sqlDriver := sqlDriverName(driver)
|
||||
|
||||
if driver == "sqlite" {
|
||||
dsn = sqliteDSN(dsn)
|
||||
}
|
||||
|
||||
sqlDB, err := sql.Open(sqlDriver, dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db open %s: %w", driver, err)
|
||||
}
|
||||
|
||||
// Connection pool tuning.
|
||||
if driver == "sqlite" {
|
||||
// SQLite: serialise with single connection to avoid SQLITE_BUSY.
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetMaxIdleConns(1)
|
||||
sqlDB.SetConnMaxLifetime(0)
|
||||
} else {
|
||||
sqlDB.SetMaxOpenConns(25)
|
||||
sqlDB.SetMaxIdleConns(10)
|
||||
sqlDB.SetConnMaxLifetime(5 * time.Minute)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := sqlDB.PingContext(ctx); err != nil {
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("db ping: %w", err)
|
||||
}
|
||||
|
||||
d := &DB{db: sqlDB, driver: driver}
|
||||
|
||||
// Enable WAL mode for SQLite (dramatically improves concurrent read performance).
|
||||
if driver == "sqlite" {
|
||||
if _, err := sqlDB.Exec(`PRAGMA journal_mode=WAL`); err != nil {
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("sqlite WAL: %w", err)
|
||||
}
|
||||
if _, err := sqlDB.Exec(`PRAGMA foreign_keys=ON`); err != nil {
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("sqlite foreign_keys: %w", err)
|
||||
}
|
||||
if _, err := sqlDB.Exec(`PRAGMA busy_timeout=5000`); err != nil {
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("sqlite busy_timeout: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := d.migrate(); err != nil {
|
||||
sqlDB.Close()
|
||||
return nil, fmt.Errorf("migrate: %w", err)
|
||||
}
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying sql.DB.
|
||||
func (d *DB) Close() error { return d.db.Close() }
|
||||
|
||||
// Driver returns the driver name (sqlite / postgres / mysql / mssql).
|
||||
func (d *DB) Driver() string { return d.driver }
|
||||
|
||||
// SQL returns the underlying *sql.DB for direct use when needed.
|
||||
func (d *DB) SQL() *sql.DB { return d.db }
|
||||
|
||||
// Exec runs a query with a per-call context timeout.
|
||||
func (d *DB) Exec(query string, args ...any) (sql.Result, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return d.db.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// QueryRow runs a single-row query.
|
||||
func (d *DB) QueryRow(query string, args ...any) *sql.Row {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return d.db.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// Query runs a multi-row query.
|
||||
func (d *DB) Query(query string, args ...any) (*sql.Rows, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return d.db.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
// WithTx runs fn inside a transaction, rolling back on error or panic.
|
||||
func (d *DB) WithTx(ctx context.Context, fn func(*sql.Tx) error) error {
|
||||
tx, err := d.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
_ = tx.Rollback()
|
||||
panic(p) // re-raise
|
||||
}
|
||||
}()
|
||||
if err := fn(tx); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// ---- Placeholder helper ----
|
||||
|
||||
// Placeholder returns the SQL parameter placeholder for the current driver.
|
||||
// SQLite and MySQL use ?, PostgreSQL uses $1, $2… MSSQL uses @p1, @p2…
|
||||
func (d *DB) Placeholder(n int) string {
|
||||
switch d.driver {
|
||||
case "postgres":
|
||||
return fmt.Sprintf("$%d", n)
|
||||
case "mssql":
|
||||
return fmt.Sprintf("@p%d", n)
|
||||
default:
|
||||
return "?"
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Private helpers ----
|
||||
|
||||
func sqlDriverName(driver string) string {
|
||||
switch driver {
|
||||
case "sqlite":
|
||||
return "sqlite" // modernc.org/sqlite registers as "sqlite"
|
||||
case "postgres":
|
||||
return "postgres"
|
||||
case "mysql":
|
||||
return "mysql"
|
||||
case "mssql":
|
||||
return "sqlserver"
|
||||
default:
|
||||
return driver
|
||||
}
|
||||
}
|
||||
|
||||
func sqliteDSN(path string) string {
|
||||
if path == "" {
|
||||
path = "./data/mail.db"
|
||||
}
|
||||
// modernc.org/sqlite DSN supports query parameters.
|
||||
return path + "?_pragma=foreign_keys(1)"
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
// GetDomain returns the domain row by name, or nil if not found.
|
||||
func (d *DB) GetDomain(ctx context.Context, name string) (*models.Domain, error) {
|
||||
row := d.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, enabled, dkim_private_enc, dkim_public, dkim_selector,
|
||||
dkim_algo, spf_policy, dmarc_policy, max_users, max_quota_bytes, created_at
|
||||
FROM domains WHERE lower(name) = lower(?)`, name)
|
||||
|
||||
var dom models.Domain
|
||||
var privEnc []byte
|
||||
err := row.Scan(
|
||||
&dom.ID, &dom.Name, &dom.Enabled,
|
||||
&privEnc, &dom.DKIMPublic, &dom.DKIMSelector,
|
||||
&dom.DKIMAlgo, &dom.SPFPolicy, &dom.DMARCPolicy,
|
||||
&dom.MaxUsers, &dom.MaxQuotaBytes, &dom.CreatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get domain: %w", err)
|
||||
}
|
||||
dom.DKIMPrivateEnc = privEnc
|
||||
return &dom, nil
|
||||
}
|
||||
|
||||
// GetDomainByID returns the domain row by ID.
|
||||
func (d *DB) GetDomainByID(ctx context.Context, id int64) (*models.Domain, error) {
|
||||
row := d.db.QueryRowContext(ctx, `
|
||||
SELECT id, name, enabled, dkim_private_enc, dkim_public, dkim_selector,
|
||||
dkim_algo, spf_policy, dmarc_policy, max_users, max_quota_bytes, created_at
|
||||
FROM domains WHERE id = ?`, id)
|
||||
|
||||
var dom models.Domain
|
||||
var privEnc []byte
|
||||
err := row.Scan(
|
||||
&dom.ID, &dom.Name, &dom.Enabled,
|
||||
&privEnc, &dom.DKIMPublic, &dom.DKIMSelector,
|
||||
&dom.DKIMAlgo, &dom.SPFPolicy, &dom.DMARCPolicy,
|
||||
&dom.MaxUsers, &dom.MaxQuotaBytes, &dom.CreatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get domain by id: %w", err)
|
||||
}
|
||||
dom.DKIMPrivateEnc = privEnc
|
||||
return &dom, nil
|
||||
}
|
||||
|
||||
// IsLocalDomain returns true if name is a known enabled domain.
|
||||
func (d *DB) IsLocalDomain(ctx context.Context, name string) (bool, error) {
|
||||
var count int
|
||||
err := d.db.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM domains WHERE lower(name)=lower(?) AND enabled=1", name).
|
||||
Scan(&count)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ListDomains returns all domains ordered by name.
|
||||
func (d *DB) ListDomains(ctx context.Context) ([]*models.Domain, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, name, enabled, dkim_private_enc, dkim_public, dkim_selector,
|
||||
dkim_algo, spf_policy, dmarc_policy, max_users, max_quota_bytes, created_at
|
||||
FROM domains ORDER BY name`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var doms []*models.Domain
|
||||
for rows.Next() {
|
||||
var dom models.Domain
|
||||
var privEnc []byte
|
||||
err := rows.Scan(
|
||||
&dom.ID, &dom.Name, &dom.Enabled,
|
||||
&privEnc, &dom.DKIMPublic, &dom.DKIMSelector,
|
||||
&dom.DKIMAlgo, &dom.SPFPolicy, &dom.DMARCPolicy,
|
||||
&dom.MaxUsers, &dom.MaxQuotaBytes, &dom.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dom.DKIMPrivateEnc = privEnc
|
||||
doms = append(doms, &dom)
|
||||
}
|
||||
return doms, rows.Err()
|
||||
}
|
||||
|
||||
// CreateDomain inserts a new domain. Returns the new ID.
|
||||
func (d *DB) CreateDomain(ctx context.Context, name, selector, algo string) (int64, error) {
|
||||
res, err := d.db.ExecContext(ctx, `
|
||||
INSERT INTO domains (name, enabled, dkim_selector, dkim_algo)
|
||||
VALUES (?, 1, ?, ?)`, name, selector, algo)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("create domain: %w", err)
|
||||
}
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// SaveDKIMKeys stores encrypted DKIM private key + public key for a domain.
|
||||
func (d *DB) SaveDKIMKeys(ctx context.Context, domainID int64, privEnc []byte, pubPEM string) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE domains SET dkim_private_enc=?, dkim_public=? WHERE id=?",
|
||||
privEnc, pubPEM, domainID)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,284 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
// IMAPMessage is a lightweight message descriptor used by the IMAP layer.
|
||||
// The raw/body blobs are NOT loaded here — fetch separately via GetMessageRaw.
|
||||
type IMAPMessage struct {
|
||||
ID int64
|
||||
MailboxID int64
|
||||
UID uint32
|
||||
MessageID string // RFC 2822 Message-ID header
|
||||
Subject string
|
||||
FromEmail string
|
||||
FromName string
|
||||
ToList string
|
||||
Date time.Time
|
||||
SizeBytes int64
|
||||
HasAttachment bool
|
||||
IsRead bool
|
||||
IsStarred bool
|
||||
IsDraft bool
|
||||
IsDeleted bool // deleted_at IS NOT NULL
|
||||
Flags string
|
||||
SpamScore int
|
||||
ReceivedAt time.Time
|
||||
}
|
||||
|
||||
// ListIMAPMessages returns all non-deleted messages in a mailbox ordered by UID ascending.
|
||||
func (d *DB) ListIMAPMessages(ctx context.Context, mailboxID int64) ([]*IMAPMessage, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, mailbox_id, uid, message_id, subject, from_email, from_name,
|
||||
to_list, date, size_bytes, has_attachment,
|
||||
is_read, is_starred, is_draft, flags, spam_score, received_at
|
||||
FROM messages
|
||||
WHERE mailbox_id = ? AND deleted_at IS NULL
|
||||
ORDER BY uid ASC`, mailboxID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []*IMAPMessage
|
||||
for rows.Next() {
|
||||
m, err := scanIMAPMessage(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, m)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// GetIMAPMessageByUID returns one message by UID within a mailbox.
|
||||
func (d *DB) GetIMAPMessageByUID(ctx context.Context, mailboxID int64, uid uint32) (*IMAPMessage, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, mailbox_id, uid, message_id, subject, from_email, from_name,
|
||||
to_list, date, size_bytes, has_attachment,
|
||||
is_read, is_starred, is_draft, flags, spam_score, received_at
|
||||
FROM messages
|
||||
WHERE mailbox_id = ? AND uid = ? AND deleted_at IS NULL
|
||||
LIMIT 1`, mailboxID, uid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
if !rows.Next() {
|
||||
return nil, nil
|
||||
}
|
||||
return scanIMAPMessage(rows)
|
||||
}
|
||||
|
||||
// SetMessageFlags updates the mutable flags for a message.
|
||||
func (d *DB) SetMessageFlags(ctx context.Context, messageID int64, isRead, isStarred, isDraft bool, extraFlags string) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE messages SET is_read=?, is_starred=?, is_draft=?, flags=? WHERE id=?",
|
||||
isRead, isStarred, isDraft, extraFlags, messageID)
|
||||
return err
|
||||
}
|
||||
|
||||
// SoftDeleteMessage marks a message as deleted (sets deleted_at).
|
||||
func (d *DB) SoftDeleteMessage(ctx context.Context, messageID int64) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE messages SET deleted_at=? WHERE id=?", time.Now().UTC(), messageID)
|
||||
return err
|
||||
}
|
||||
|
||||
// HardDeleteMessages physically removes all soft-deleted messages from a mailbox.
|
||||
// Returns the UIDs of deleted messages (for EXPUNGE responses).
|
||||
func (d *DB) HardDeleteMessages(ctx context.Context, mailboxID int64) ([]uint32, error) {
|
||||
rows, err := d.db.QueryContext(ctx,
|
||||
"SELECT uid FROM messages WHERE mailbox_id=? AND deleted_at IS NOT NULL ORDER BY uid ASC",
|
||||
mailboxID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var uids []uint32
|
||||
for rows.Next() {
|
||||
var uid uint32
|
||||
if err := rows.Scan(&uid); err != nil {
|
||||
rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
uids = append(uids, uid)
|
||||
}
|
||||
rows.Close()
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(uids) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Delete attachments first (FK).
|
||||
_, err = d.db.ExecContext(ctx, `
|
||||
DELETE FROM attachments WHERE message_id IN (
|
||||
SELECT id FROM messages WHERE mailbox_id=? AND deleted_at IS NOT NULL
|
||||
)`, mailboxID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete attachments: %w", err)
|
||||
}
|
||||
_, err = d.db.ExecContext(ctx,
|
||||
"DELETE FROM messages WHERE mailbox_id=? AND deleted_at IS NOT NULL", mailboxID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete messages: %w", err)
|
||||
}
|
||||
return uids, nil
|
||||
}
|
||||
|
||||
// CopyMessageToMailbox duplicates a message row to another mailbox.
|
||||
// Returns the new UID.
|
||||
func (d *DB) CopyMessageToMailbox(ctx context.Context, srcMsgID, destMailboxID, userID int64) (uint32, error) {
|
||||
// Read source.
|
||||
var src struct {
|
||||
mailboxID int64
|
||||
uid uint32
|
||||
messageID string
|
||||
subject string
|
||||
fromEmail string
|
||||
fromName string
|
||||
toList string
|
||||
ccList string
|
||||
bccList string
|
||||
replyTo string
|
||||
date time.Time
|
||||
bodyTextEnc []byte
|
||||
bodyHTMLEnc []byte
|
||||
rawEnc []byte
|
||||
sizeBytes int64
|
||||
hasAttachment bool
|
||||
isRead bool
|
||||
isStarred bool
|
||||
isDraft bool
|
||||
flags string
|
||||
spamScore int
|
||||
}
|
||||
err := d.db.QueryRowContext(ctx, `
|
||||
SELECT mailbox_id, uid, message_id, subject, from_email, from_name,
|
||||
to_list, cc_list, bcc_list, reply_to, date,
|
||||
body_text_enc, body_html_enc, raw_enc,
|
||||
size_bytes, has_attachment, is_read, is_starred, is_draft,
|
||||
flags, spam_score
|
||||
FROM messages WHERE id=? AND deleted_at IS NULL`, srcMsgID).Scan(
|
||||
&src.mailboxID, &src.uid, &src.messageID, &src.subject,
|
||||
&src.fromEmail, &src.fromName, &src.toList, &src.ccList, &src.bccList, &src.replyTo,
|
||||
&src.date, &src.bodyTextEnc, &src.bodyHTMLEnc, &src.rawEnc,
|
||||
&src.sizeBytes, &src.hasAttachment, &src.isRead, &src.isStarred, &src.isDraft,
|
||||
&src.flags, &src.spamScore,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, fmt.Errorf("source message %d not found", srcMsgID)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("copy message read: %w", err)
|
||||
}
|
||||
|
||||
// Allocate UID in destination.
|
||||
uid, err := d.NextUID(ctx, destMailboxID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("copy message uid: %w", err)
|
||||
}
|
||||
|
||||
ins := &MessageInsert{
|
||||
MailboxID: destMailboxID,
|
||||
UID: uid,
|
||||
MessageID: src.messageID,
|
||||
Subject: src.subject,
|
||||
FromEmail: src.fromEmail,
|
||||
FromName: src.fromName,
|
||||
ToList: src.toList,
|
||||
CCList: src.ccList,
|
||||
BCCList: src.bccList,
|
||||
ReplyTo: src.replyTo,
|
||||
Date: src.date,
|
||||
BodyTextEnc: src.bodyTextEnc,
|
||||
BodyHTMLEnc: src.bodyHTMLEnc,
|
||||
RawEnc: src.rawEnc,
|
||||
SizeBytes: src.sizeBytes,
|
||||
HasAttachment: src.hasAttachment,
|
||||
IsRead: src.isRead,
|
||||
IsStarred: src.isStarred,
|
||||
IsDraft: src.isDraft,
|
||||
Flags: src.flags,
|
||||
SpamScore: src.spamScore,
|
||||
}
|
||||
if _, err := d.InsertMessage(ctx, ins); err != nil {
|
||||
return 0, fmt.Errorf("copy message insert: %w", err)
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
||||
// RenameMailbox updates the name field of a mailbox.
|
||||
func (d *DB) RenameMailbox(ctx context.Context, mailboxID int64, newName string) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE mailboxes SET name=? WHERE id=?", newName, mailboxID)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetMailboxSubscribed updates the subscribed flag on a mailbox.
|
||||
func (d *DB) SetMailboxSubscribed(ctx context.Context, mailboxID int64, subscribed bool) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE mailboxes SET subscribed=? WHERE id=?", subscribed, mailboxID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetMailboxMessageCounts returns (total, unseen) counts for a mailbox.
|
||||
func (d *DB) GetMailboxMessageCounts(ctx context.Context, mailboxID int64) (total, unseen int64, err error) {
|
||||
err = d.db.QueryRowContext(ctx, `
|
||||
SELECT COUNT(*), COUNT(CASE WHEN is_read=0 THEN 1 END)
|
||||
FROM messages WHERE mailbox_id=? AND deleted_at IS NULL`, mailboxID).Scan(&total, &unseen)
|
||||
return
|
||||
}
|
||||
|
||||
// GetMailboxSize returns the total size in bytes of all messages in a mailbox.
|
||||
func (d *DB) GetMailboxSize(ctx context.Context, mailboxID int64) (int64, error) {
|
||||
var sz sql.NullInt64
|
||||
err := d.db.QueryRowContext(ctx,
|
||||
"SELECT SUM(size_bytes) FROM messages WHERE mailbox_id=? AND deleted_at IS NULL",
|
||||
mailboxID).Scan(&sz)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return sz.Int64, nil
|
||||
}
|
||||
|
||||
// ---- helpers ----
|
||||
|
||||
func scanIMAPMessage(rows *sql.Rows) (*IMAPMessage, error) {
|
||||
m := &IMAPMessage{}
|
||||
err := rows.Scan(
|
||||
&m.ID, &m.MailboxID, &m.UID, &m.MessageID, &m.Subject,
|
||||
&m.FromEmail, &m.FromName, &m.ToList,
|
||||
&m.Date, &m.SizeBytes, &m.HasAttachment,
|
||||
&m.IsRead, &m.IsStarred, &m.IsDraft,
|
||||
&m.Flags, &m.SpamScore, &m.ReceivedAt,
|
||||
)
|
||||
return m, err
|
||||
}
|
||||
|
||||
// mailboxTypeToAttr converts our type string to an IMAP special-use string.
|
||||
// Callers handle the conversion to imap.MailboxAttr themselves.
|
||||
func MailboxTypeToSpecialUse(mboxType string) string {
|
||||
switch mboxType {
|
||||
case models.MailboxSent:
|
||||
return `\Sent`
|
||||
case models.MailboxDrafts:
|
||||
return `\Drafts`
|
||||
case models.MailboxTrash:
|
||||
return `\Trash`
|
||||
case models.MailboxSpam:
|
||||
return `\Junk`
|
||||
case models.MailboxArchive:
|
||||
return `\Archive`
|
||||
case models.MailboxInbox:
|
||||
return `\Inbox`
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,401 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
// GetMailbox returns the mailbox with the given name for a user, or nil.
|
||||
func (d *DB) GetMailbox(ctx context.Context, userID int64, name string) (*models.Mailbox, error) {
|
||||
row := d.db.QueryRowContext(ctx, `
|
||||
SELECT id, user_id, name, type, parent_id, uid_validity, uid_next, subscribed, created_at
|
||||
FROM mailboxes WHERE user_id=? AND name=?`, userID, name)
|
||||
return scanMailbox(row)
|
||||
}
|
||||
|
||||
// GetMailboxByType returns the first mailbox of the given type for a user.
|
||||
func (d *DB) GetMailboxByType(ctx context.Context, userID int64, mboxType string) (*models.Mailbox, error) {
|
||||
row := d.db.QueryRowContext(ctx, `
|
||||
SELECT id, user_id, name, type, parent_id, uid_validity, uid_next, subscribed, created_at
|
||||
FROM mailboxes WHERE user_id=? AND type=? LIMIT 1`, userID, mboxType)
|
||||
return scanMailbox(row)
|
||||
}
|
||||
|
||||
// GetMailboxByID returns the mailbox by ID.
|
||||
func (d *DB) GetMailboxByID(ctx context.Context, id int64) (*models.Mailbox, error) {
|
||||
row := d.db.QueryRowContext(ctx, `
|
||||
SELECT id, user_id, name, type, parent_id, uid_validity, uid_next, subscribed, created_at
|
||||
FROM mailboxes WHERE id=?`, id)
|
||||
return scanMailbox(row)
|
||||
}
|
||||
|
||||
// ListMailboxes returns all subscribed mailboxes for a user, ordered by name.
|
||||
func (d *DB) ListMailboxes(ctx context.Context, userID int64) ([]*models.Mailbox, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, user_id, name, type, parent_id, uid_validity, uid_next, subscribed, created_at
|
||||
FROM mailboxes WHERE user_id=? ORDER BY name`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var mbs []*models.Mailbox
|
||||
for rows.Next() {
|
||||
mb, err := scanMailboxRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mbs = append(mbs, mb)
|
||||
}
|
||||
return mbs, rows.Err()
|
||||
}
|
||||
|
||||
// CreateMailbox creates a mailbox. Returns the new mailbox with uid_validity set.
|
||||
func (d *DB) CreateMailbox(ctx context.Context, userID int64, name, mboxType string, parentID *int64) (*models.Mailbox, error) {
|
||||
uidValidity := uint32(rand.Int31()) //nolint:gosec — not a security value
|
||||
if uidValidity == 0 {
|
||||
uidValidity = 1
|
||||
}
|
||||
|
||||
res, err := d.db.ExecContext(ctx, `
|
||||
INSERT INTO mailboxes (user_id, name, type, parent_id, uid_validity, uid_next, subscribed, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, 1, 1, ?)`,
|
||||
userID, name, mboxType, parentID, uidValidity, time.Now().UTC())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create mailbox: %w", err)
|
||||
}
|
||||
id, _ := res.LastInsertId()
|
||||
|
||||
return &models.Mailbox{
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
Type: mboxType,
|
||||
ParentID: parentID,
|
||||
UIDValidity: uidValidity,
|
||||
UIDNext: 1,
|
||||
Subscribed: true,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateDefaultMailboxes creates the standard mailbox set for a new user.
|
||||
// Idempotent — skips any that already exist.
|
||||
func (d *DB) CreateDefaultMailboxes(ctx context.Context, userID int64) error {
|
||||
defaults := []struct {
|
||||
name string
|
||||
mboxType string
|
||||
}{
|
||||
{"INBOX", models.MailboxInbox},
|
||||
{"Sent", models.MailboxSent},
|
||||
{"Drafts", models.MailboxDrafts},
|
||||
{"Trash", models.MailboxTrash},
|
||||
{"Spam", models.MailboxSpam},
|
||||
{"Archive", models.MailboxArchive},
|
||||
}
|
||||
|
||||
for _, mb := range defaults {
|
||||
existing, err := d.GetMailbox(ctx, userID, mb.name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
continue
|
||||
}
|
||||
if _, err := d.CreateMailbox(ctx, userID, mb.name, mb.mboxType, nil); err != nil {
|
||||
return fmt.Errorf("create default mailbox %s: %w", mb.name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NextUID allocates the next UID for a mailbox atomically.
|
||||
// Returns the UID to use for the new message.
|
||||
func (d *DB) NextUID(ctx context.Context, mailboxID int64) (uint32, error) {
|
||||
tx, err := d.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer tx.Rollback() //nolint:errcheck
|
||||
|
||||
var next uint32
|
||||
err = tx.QueryRowContext(ctx,
|
||||
"SELECT uid_next FROM mailboxes WHERE id=?", mailboxID).Scan(&next)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read uid_next: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx,
|
||||
"UPDATE mailboxes SET uid_next=uid_next+1 WHERE id=?", mailboxID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("increment uid_next: %w", err)
|
||||
}
|
||||
|
||||
return next, tx.Commit()
|
||||
}
|
||||
|
||||
// ---- Message operations ----
|
||||
|
||||
// InsertMessage stores a message record (body is already encrypted; call SaveRawBody separately).
|
||||
func (d *DB) InsertMessage(ctx context.Context, m *MessageInsert) (int64, error) {
|
||||
res, err := d.db.ExecContext(ctx, `
|
||||
INSERT INTO messages
|
||||
(mailbox_id, uid, message_id, subject, from_email, from_name,
|
||||
to_list, cc_list, bcc_list, reply_to, date,
|
||||
body_text_enc, body_html_enc, raw_enc,
|
||||
size_bytes, has_attachment, is_read, is_starred, is_draft,
|
||||
flags, spam_score, received_at)
|
||||
VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)`,
|
||||
m.MailboxID, m.UID, m.MessageID, m.Subject, m.FromEmail, m.FromName,
|
||||
m.ToList, m.CCList, m.BCCList, m.ReplyTo, m.Date,
|
||||
m.BodyTextEnc, m.BodyHTMLEnc, m.RawEnc,
|
||||
m.SizeBytes, m.HasAttachment, m.IsRead, m.IsStarred, m.IsDraft,
|
||||
m.Flags, m.SpamScore, time.Now().UTC(),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert message: %w", err)
|
||||
}
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// MessageInsert is the data transfer object for inserting a new message.
|
||||
type MessageInsert struct {
|
||||
MailboxID int64
|
||||
UID uint32
|
||||
MessageID string
|
||||
Subject string
|
||||
FromEmail string
|
||||
FromName string
|
||||
ToList string
|
||||
CCList string
|
||||
BCCList string
|
||||
ReplyTo string
|
||||
Date time.Time
|
||||
BodyTextEnc []byte
|
||||
BodyHTMLEnc []byte
|
||||
RawEnc []byte
|
||||
SizeBytes int64
|
||||
HasAttachment bool
|
||||
IsRead bool
|
||||
IsStarred bool
|
||||
IsDraft bool
|
||||
Flags string
|
||||
SpamScore int
|
||||
}
|
||||
|
||||
// InsertAttachment stores an attachment record for a message.
|
||||
func (d *DB) InsertAttachment(ctx context.Context, a *AttachmentInsert) (int64, error) {
|
||||
res, err := d.db.ExecContext(ctx, `
|
||||
INSERT INTO attachments
|
||||
(message_id, filename, content_type, size_bytes, data_enc, data_path,
|
||||
content_id, inline, mime_path)
|
||||
VALUES (?,?,?,?,?,?,?,?,?)`,
|
||||
a.MessageID, a.Filename, a.ContentType, a.SizeBytes,
|
||||
a.DataEnc, a.DataPath, a.ContentID, a.Inline, a.MIMEPath,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("insert attachment: %w", err)
|
||||
}
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// AttachmentInsert is the data transfer object for inserting an attachment.
|
||||
type AttachmentInsert struct {
|
||||
MessageID int64
|
||||
Filename string
|
||||
ContentType string
|
||||
SizeBytes int64
|
||||
DataEnc []byte
|
||||
DataPath string
|
||||
ContentID string
|
||||
Inline bool
|
||||
MIMEPath string
|
||||
}
|
||||
|
||||
// GetMessageRaw returns the encrypted raw blob for a message.
|
||||
func (d *DB) GetMessageRaw(ctx context.Context, messageID int64) ([]byte, error) {
|
||||
var raw []byte
|
||||
err := d.db.QueryRowContext(ctx,
|
||||
"SELECT raw_enc FROM messages WHERE id=?", messageID).Scan(&raw)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return raw, err
|
||||
}
|
||||
|
||||
// ListMessages returns messages in a mailbox ordered by UID descending.
|
||||
// Only non-deleted messages are returned.
|
||||
func (d *DB) ListMessages(ctx context.Context, mailboxID int64, limit, offset int) ([]*models.Message, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, mailbox_id, uid, message_id, subject, from_email, from_name,
|
||||
to_list, cc_list, bcc_list, reply_to, date,
|
||||
size_bytes, has_attachment, is_read, is_starred, is_draft,
|
||||
flags, spam_score, received_at
|
||||
FROM messages
|
||||
WHERE mailbox_id=? AND deleted_at IS NULL
|
||||
ORDER BY uid DESC
|
||||
LIMIT ? OFFSET ?`, mailboxID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var msgs []*models.Message
|
||||
for rows.Next() {
|
||||
var m models.Message
|
||||
err := rows.Scan(
|
||||
&m.ID, &m.MailboxID, &m.UID, &m.MessageID, &m.Subject,
|
||||
&m.FromEmail, &m.FromName, &m.ToList, &m.CCList, &m.BCCList,
|
||||
&m.ReplyTo, &m.Date, &m.SizeBytes, &m.HasAttachment,
|
||||
&m.IsRead, &m.IsStarred, &m.IsDraft, &m.Flags,
|
||||
&m.SpamScore, &m.ReceivedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgs = append(msgs, &m)
|
||||
}
|
||||
return msgs, rows.Err()
|
||||
}
|
||||
|
||||
// CountUnread returns the number of unread messages in a mailbox.
|
||||
func (d *DB) CountUnread(ctx context.Context, mailboxID int64) (int, error) {
|
||||
var n int
|
||||
err := d.db.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM messages WHERE mailbox_id=? AND is_read=0 AND deleted_at IS NULL",
|
||||
mailboxID).Scan(&n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// ---- Queue operations ----
|
||||
|
||||
// EnqueueMessage inserts a delivery queue entry. Returns the new queue ID.
|
||||
func (d *DB) EnqueueMessage(ctx context.Context, domainID int64, from, to, msgID string, rawEnc []byte, maxAgeHours int) (int64, error) {
|
||||
expires := time.Now().UTC().Add(time.Duration(maxAgeHours) * time.Hour)
|
||||
res, err := d.db.ExecContext(ctx, `
|
||||
INSERT INTO queue
|
||||
(domain_id, from_addr, to_addr, raw_enc, message_id, status,
|
||||
attempts, next_attempt, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?, 'pending', 0, ?, ?, ?)`,
|
||||
domainID, from, to, rawEnc, msgID,
|
||||
time.Now().UTC(), time.Now().UTC(), expires)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("enqueue: %w", err)
|
||||
}
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// PeekQueue returns up to limit pending/retry-eligible queue entries.
|
||||
func (d *DB) PeekQueue(ctx context.Context, limit int) ([]QueueRow, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, domain_id, from_addr, to_addr, raw_enc, message_id,
|
||||
status, attempts, expires_at
|
||||
FROM queue
|
||||
WHERE status IN ('pending','failed')
|
||||
AND next_attempt <= ?
|
||||
AND expires_at > ?
|
||||
ORDER BY next_attempt ASC
|
||||
LIMIT ?`,
|
||||
time.Now().UTC(), time.Now().UTC(), limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []QueueRow
|
||||
for rows.Next() {
|
||||
var q QueueRow
|
||||
var domainID sql.NullInt64
|
||||
err := rows.Scan(
|
||||
&q.ID, &domainID, &q.FromAddr, &q.ToAddr,
|
||||
&q.RawEnc, &q.MessageID, &q.Status, &q.Attempts, &q.ExpiresAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if domainID.Valid {
|
||||
q.DomainID = domainID.Int64
|
||||
}
|
||||
out = append(out, q)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// QueueRow is a minimal queue entry for the delivery worker.
|
||||
type QueueRow struct {
|
||||
ID int64
|
||||
DomainID int64
|
||||
FromAddr string
|
||||
ToAddr string
|
||||
RawEnc []byte
|
||||
MessageID string
|
||||
Status string
|
||||
Attempts int
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// SetQueueStatus updates the status of a queue entry.
|
||||
func (d *DB) SetQueueStatus(ctx context.Context, id int64, status, errMsg string, nextAttempt *time.Time) error {
|
||||
_, err := d.db.ExecContext(ctx, `
|
||||
UPDATE queue
|
||||
SET status=?, attempts=attempts+1, last_attempt=?,
|
||||
error_log=error_log || ?, next_attempt=COALESCE(?, next_attempt)
|
||||
WHERE id=?`,
|
||||
status, time.Now().UTC(),
|
||||
fmt.Sprintf("[%s] %s\n", time.Now().UTC().Format(time.RFC3339), errMsg),
|
||||
nextAttempt, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// LogDelivery inserts a delivery log entry.
|
||||
func (d *DB) LogDelivery(ctx context.Context, queueID int64, from, to, status string, smtpCode int, smtpMsg, mxHost string) error {
|
||||
_, err := d.db.ExecContext(ctx, `
|
||||
INSERT INTO delivery_log (queue_id, from_addr, to_addr, status, smtp_code, smtp_message, mx_host, created_at)
|
||||
VALUES (?,?,?,?,?,?,?,?)`,
|
||||
queueID, from, to, status, smtpCode, smtpMsg, mxHost, time.Now().UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
// ---- private ----
|
||||
|
||||
func scanMailbox(row *sql.Row) (*models.Mailbox, error) {
|
||||
var mb models.Mailbox
|
||||
var parentID sql.NullInt64
|
||||
err := row.Scan(
|
||||
&mb.ID, &mb.UserID, &mb.Name, &mb.Type,
|
||||
&parentID, &mb.UIDValidity, &mb.UIDNext, &mb.Subscribed, &mb.CreatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan mailbox: %w", err)
|
||||
}
|
||||
if parentID.Valid {
|
||||
id := parentID.Int64
|
||||
mb.ParentID = &id
|
||||
}
|
||||
return &mb, nil
|
||||
}
|
||||
|
||||
func scanMailboxRow(rows *sql.Rows) (*models.Mailbox, error) {
|
||||
var mb models.Mailbox
|
||||
var parentID sql.NullInt64
|
||||
err := rows.Scan(
|
||||
&mb.ID, &mb.UserID, &mb.Name, &mb.Type,
|
||||
&parentID, &mb.UIDValidity, &mb.UIDNext, &mb.Subscribed, &mb.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parentID.Valid {
|
||||
id := parentID.Int64
|
||||
mb.ParentID = &id
|
||||
}
|
||||
return &mb, nil
|
||||
}
|
||||
@@ -0,0 +1,337 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// migration is a versioned schema change.
|
||||
type migration struct {
|
||||
version int
|
||||
up string // SQL to apply
|
||||
}
|
||||
|
||||
// migrations must be append-only. Never edit an applied migration.
|
||||
var migrations = []migration{
|
||||
{1, schemav1},
|
||||
}
|
||||
|
||||
// migrate applies any unapplied migrations in order.
|
||||
func (d *DB) migrate() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Ensure migrations table exists.
|
||||
_, err := d.db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create migrations table: %w", err)
|
||||
}
|
||||
|
||||
for _, m := range migrations {
|
||||
var count int
|
||||
err := d.db.QueryRowContext(ctx,
|
||||
"SELECT COUNT(*) FROM schema_migrations WHERE version = ?", m.version).Scan(&count)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check migration %d: %w", m.version, err)
|
||||
}
|
||||
if count > 0 {
|
||||
continue // already applied
|
||||
}
|
||||
|
||||
if err := d.WithTx(ctx, func(tx *sql.Tx) error {
|
||||
if _, err := tx.ExecContext(ctx, m.up); err != nil {
|
||||
return fmt.Errorf("apply migration %d: %w", m.version, err)
|
||||
}
|
||||
_, err := tx.ExecContext(ctx,
|
||||
"INSERT INTO schema_migrations (version) VALUES (?)", m.version)
|
||||
return err
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("[db] applied migration %d\n", m.version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---- Schema v1 (initial) ----
|
||||
|
||||
const schemav1 = `
|
||||
-- Domains
|
||||
CREATE TABLE IF NOT EXISTS domains (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
enabled BOOLEAN NOT NULL DEFAULT 1,
|
||||
dkim_private_enc BLOB,
|
||||
dkim_public TEXT,
|
||||
dkim_selector TEXT NOT NULL DEFAULT 'mail',
|
||||
dkim_algo TEXT NOT NULL DEFAULT 'rsa2048',
|
||||
spf_policy TEXT,
|
||||
dmarc_policy TEXT,
|
||||
max_users INTEGER NOT NULL DEFAULT 0,
|
||||
max_quota_bytes INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Users
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
domain_id INTEGER NOT NULL REFERENCES domains(id) ON DELETE CASCADE,
|
||||
username TEXT NOT NULL,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
quota_bytes INTEGER NOT NULL DEFAULT 1073741824,
|
||||
used_bytes INTEGER NOT NULL DEFAULT 0,
|
||||
enabled BOOLEAN NOT NULL DEFAULT 1,
|
||||
admin BOOLEAN NOT NULL DEFAULT 0,
|
||||
domain_admin BOOLEAN NOT NULL DEFAULT 0,
|
||||
mfa_secret_enc BLOB,
|
||||
mfa_enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
recovery_codes_enc BLOB,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login TIMESTAMP
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users(email);
|
||||
CREATE INDEX IF NOT EXISTS idx_users_domain ON users(domain_id);
|
||||
|
||||
-- User aliases
|
||||
CREATE TABLE IF NOT EXISTS user_aliases (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
alias_email TEXT NOT NULL UNIQUE
|
||||
);
|
||||
|
||||
-- Mailboxes (IMAP folders)
|
||||
CREATE TABLE IF NOT EXISTS mailboxes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL DEFAULT 'custom',
|
||||
parent_id INTEGER REFERENCES mailboxes(id) ON DELETE CASCADE,
|
||||
uid_validity INTEGER NOT NULL DEFAULT 1,
|
||||
uid_next INTEGER NOT NULL DEFAULT 1,
|
||||
subscribed BOOLEAN NOT NULL DEFAULT 1,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(user_id, name)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_mailboxes_user ON mailboxes(user_id);
|
||||
|
||||
-- Messages
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mailbox_id INTEGER NOT NULL REFERENCES mailboxes(id) ON DELETE CASCADE,
|
||||
uid INTEGER NOT NULL,
|
||||
message_id TEXT,
|
||||
subject TEXT NOT NULL DEFAULT '',
|
||||
from_email TEXT NOT NULL DEFAULT '',
|
||||
from_name TEXT NOT NULL DEFAULT '',
|
||||
to_list TEXT NOT NULL DEFAULT '',
|
||||
cc_list TEXT NOT NULL DEFAULT '',
|
||||
bcc_list TEXT NOT NULL DEFAULT '',
|
||||
reply_to TEXT NOT NULL DEFAULT '',
|
||||
date TIMESTAMP,
|
||||
body_text_enc BLOB,
|
||||
body_html_enc BLOB,
|
||||
raw_enc BLOB,
|
||||
size_bytes INTEGER NOT NULL DEFAULT 0,
|
||||
has_attachment BOOLEAN NOT NULL DEFAULT 0,
|
||||
is_read BOOLEAN NOT NULL DEFAULT 0,
|
||||
is_starred BOOLEAN NOT NULL DEFAULT 0,
|
||||
is_draft BOOLEAN NOT NULL DEFAULT 0,
|
||||
flags TEXT NOT NULL DEFAULT '',
|
||||
spam_score INTEGER NOT NULL DEFAULT 0,
|
||||
received_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP,
|
||||
UNIQUE(mailbox_id, uid)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_mailbox ON messages(mailbox_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_uid ON messages(mailbox_id, uid);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_date ON messages(mailbox_id, date);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_deleted ON messages(mailbox_id, deleted_at);
|
||||
|
||||
-- Attachments
|
||||
CREATE TABLE IF NOT EXISTS attachments (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
message_id INTEGER NOT NULL REFERENCES messages(id) ON DELETE CASCADE,
|
||||
filename TEXT NOT NULL,
|
||||
content_type TEXT NOT NULL,
|
||||
size_bytes INTEGER NOT NULL DEFAULT 0,
|
||||
data_enc BLOB,
|
||||
data_path TEXT,
|
||||
content_id TEXT,
|
||||
inline BOOLEAN NOT NULL DEFAULT 0,
|
||||
mime_path TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_attachments_message ON attachments(message_id);
|
||||
|
||||
-- Delivery queue
|
||||
CREATE TABLE IF NOT EXISTS queue (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
domain_id INTEGER REFERENCES domains(id),
|
||||
from_addr TEXT NOT NULL,
|
||||
to_addr TEXT NOT NULL,
|
||||
raw_enc BLOB NOT NULL,
|
||||
message_id TEXT,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
attempts INTEGER NOT NULL DEFAULT 0,
|
||||
last_attempt TIMESTAMP,
|
||||
next_attempt TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
error_log TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_queue_status ON queue(status, next_attempt);
|
||||
|
||||
-- Delivery log
|
||||
CREATE TABLE IF NOT EXISTS delivery_log (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
queue_id INTEGER REFERENCES queue(id) ON DELETE SET NULL,
|
||||
from_addr TEXT NOT NULL,
|
||||
to_addr TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
smtp_code INTEGER NOT NULL DEFAULT 0,
|
||||
smtp_message TEXT NOT NULL DEFAULT '',
|
||||
mx_host TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_delivery_log_created ON delivery_log(created_at);
|
||||
|
||||
-- Sessions
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
token_hash TEXT NOT NULL UNIQUE,
|
||||
ip TEXT NOT NULL DEFAULT '',
|
||||
user_agent TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_token ON sessions(token_hash);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_user ON sessions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at);
|
||||
|
||||
-- IP bans
|
||||
CREATE TABLE IF NOT EXISTS ip_bans (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ip TEXT NOT NULL UNIQUE,
|
||||
reason TEXT NOT NULL DEFAULT '',
|
||||
banned_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP,
|
||||
released_by TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_ip_bans_ip ON ip_bans(ip);
|
||||
CREATE INDEX IF NOT EXISTS idx_ip_bans_expires ON ip_bans(expires_at);
|
||||
|
||||
-- Login attempts
|
||||
CREATE TABLE IF NOT EXISTS login_attempts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ip TEXT NOT NULL,
|
||||
user_email TEXT NOT NULL DEFAULT '',
|
||||
success BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_login_attempts_ip ON login_attempts(ip, created_at);
|
||||
|
||||
-- Security events
|
||||
CREATE TABLE IF NOT EXISTS security_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
type TEXT NOT NULL,
|
||||
ip TEXT NOT NULL DEFAULT '',
|
||||
user_id INTEGER REFERENCES users(id) ON DELETE SET NULL,
|
||||
detail TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_security_events_created ON security_events(created_at);
|
||||
|
||||
-- External accounts (Gmail / Outlook / custom IMAP)
|
||||
CREATE TABLE IF NOT EXISTS external_accounts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
provider TEXT NOT NULL,
|
||||
email_address TEXT NOT NULL,
|
||||
display_name TEXT NOT NULL DEFAULT '',
|
||||
access_token_enc BLOB,
|
||||
refresh_token_enc BLOB,
|
||||
token_expiry TIMESTAMP,
|
||||
imap_host TEXT NOT NULL DEFAULT '',
|
||||
imap_port INTEGER NOT NULL DEFAULT 993,
|
||||
smtp_host TEXT NOT NULL DEFAULT '',
|
||||
smtp_port INTEGER NOT NULL DEFAULT 587,
|
||||
enabled BOOLEAN NOT NULL DEFAULT 1,
|
||||
sync_enabled BOOLEAN NOT NULL DEFAULT 1,
|
||||
last_sync TIMESTAMP,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_ext_accounts_user ON external_accounts(user_id);
|
||||
|
||||
-- Address books (CardDAV)
|
||||
CREATE TABLE IF NOT EXISTS address_books (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
color TEXT NOT NULL DEFAULT '#4A90E2',
|
||||
sync_token INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Contacts (CardDAV)
|
||||
CREATE TABLE IF NOT EXISTS contacts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
address_book_id INTEGER NOT NULL REFERENCES address_books(id) ON DELETE CASCADE,
|
||||
uid TEXT NOT NULL,
|
||||
vcard_enc BLOB NOT NULL,
|
||||
etag TEXT NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(address_book_id, uid)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_contacts_book ON contacts(address_book_id);
|
||||
|
||||
-- Calendars (CalDAV)
|
||||
CREATE TABLE IF NOT EXISTS calendars (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
color TEXT NOT NULL DEFAULT '#4CAF50',
|
||||
timezone TEXT NOT NULL DEFAULT 'UTC',
|
||||
sync_token INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Calendar events (CalDAV)
|
||||
CREATE TABLE IF NOT EXISTS calendar_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
calendar_id INTEGER NOT NULL REFERENCES calendars(id) ON DELETE CASCADE,
|
||||
uid TEXT NOT NULL,
|
||||
ical_enc BLOB NOT NULL,
|
||||
etag TEXT NOT NULL,
|
||||
dt_start TIMESTAMP,
|
||||
dt_end TIMESTAMP,
|
||||
summary TEXT NOT NULL DEFAULT '',
|
||||
recurring BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
UNIQUE(calendar_id, uid)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_events_calendar ON calendar_events(calendar_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_events_dtstart ON calendar_events(calendar_id, dt_start);
|
||||
|
||||
-- Spam Bayesian tokens (per-user)
|
||||
CREATE TABLE IF NOT EXISTS spam_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
token TEXT NOT NULL,
|
||||
spam_count INTEGER NOT NULL DEFAULT 0,
|
||||
ham_count INTEGER NOT NULL DEFAULT 0,
|
||||
UNIQUE(user_id, token)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_spam_tokens_user ON spam_tokens(user_id, token);
|
||||
`
|
||||
@@ -0,0 +1,158 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
// GetUserByEmail returns the user with the given email (case-insensitive), or nil.
|
||||
func (d *DB) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
|
||||
row := d.db.QueryRowContext(ctx, `
|
||||
SELECT id, domain_id, username, email, password_hash, display_name,
|
||||
quota_bytes, used_bytes, enabled, admin, domain_admin,
|
||||
mfa_secret_enc, mfa_enabled, recovery_codes_enc, created_at, last_login
|
||||
FROM users WHERE lower(email)=lower(?)`, email)
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
// GetUserByID returns the user with the given ID, or nil.
|
||||
func (d *DB) GetUserByID(ctx context.Context, id int64) (*models.User, error) {
|
||||
row := d.db.QueryRowContext(ctx, `
|
||||
SELECT id, domain_id, username, email, password_hash, display_name,
|
||||
quota_bytes, used_bytes, enabled, admin, domain_admin,
|
||||
mfa_secret_enc, mfa_enabled, recovery_codes_enc, created_at, last_login
|
||||
FROM users WHERE id=?`, id)
|
||||
return scanUser(row)
|
||||
}
|
||||
|
||||
// UserExistsByEmail returns true if any user (enabled or not) has this email or alias.
|
||||
func (d *DB) UserExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
var count int
|
||||
err := d.db.QueryRowContext(ctx, `
|
||||
SELECT COUNT(*) FROM (
|
||||
SELECT 1 FROM users WHERE lower(email)=lower(?) AND enabled=1
|
||||
UNION ALL
|
||||
SELECT 1 FROM user_aliases WHERE lower(alias_email)=lower(?)
|
||||
)`, email, email).Scan(&count)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// ResolveEmail returns the canonical user for an email or alias, or nil.
|
||||
func (d *DB) ResolveEmail(ctx context.Context, email string) (*models.User, error) {
|
||||
// Direct match first.
|
||||
u, err := d.GetUserByEmail(ctx, email)
|
||||
if err != nil || u != nil {
|
||||
return u, err
|
||||
}
|
||||
|
||||
// Alias match.
|
||||
var userID int64
|
||||
err = d.db.QueryRowContext(ctx,
|
||||
"SELECT user_id FROM user_aliases WHERE lower(alias_email)=lower(?)", email).
|
||||
Scan(&userID)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.GetUserByID(ctx, userID)
|
||||
}
|
||||
|
||||
// CreateUser inserts a new user. Returns the new ID.
|
||||
func (d *DB) CreateUser(ctx context.Context, domainID int64, username, email, passwordHash, displayName string, quotaBytes int64, domainAdmin bool) (int64, error) {
|
||||
res, err := d.db.ExecContext(ctx, `
|
||||
INSERT INTO users
|
||||
(domain_id, username, email, password_hash, display_name, quota_bytes,
|
||||
enabled, admin, domain_admin, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 1, 0, ?, ?)`,
|
||||
domainID, username, email, passwordHash, displayName, quotaBytes, domainAdmin,
|
||||
time.Now().UTC())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("create user: %w", err)
|
||||
}
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
// UpdateUsedBytes sets the cached used_bytes for a user (approximate, updated on store).
|
||||
func (d *DB) UpdateUsedBytes(ctx context.Context, userID int64, delta int64) error {
|
||||
_, err := d.db.ExecContext(ctx,
|
||||
"UPDATE users SET used_bytes = MAX(0, used_bytes + ?) WHERE id=?",
|
||||
delta, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateLastLogin sets last_login to now.
|
||||
func (d *DB) UpdateLastLogin(ctx context.Context, userID int64) {
|
||||
d.db.ExecContext(ctx, //nolint:errcheck — best-effort
|
||||
"UPDATE users SET last_login=? WHERE id=?", time.Now().UTC(), userID)
|
||||
}
|
||||
|
||||
// ListUsers returns all users for a domain.
|
||||
func (d *DB) ListUsers(ctx context.Context, domainID int64) ([]*models.User, error) {
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT id, domain_id, username, email, password_hash, display_name,
|
||||
quota_bytes, used_bytes, enabled, admin, domain_admin,
|
||||
mfa_secret_enc, mfa_enabled, recovery_codes_enc, created_at, last_login
|
||||
FROM users WHERE domain_id=? ORDER BY email`, domainID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []*models.User
|
||||
for rows.Next() {
|
||||
var u models.User
|
||||
var mfaEnc, rcEnc []byte
|
||||
var lastLogin sql.NullTime
|
||||
err := rows.Scan(
|
||||
&u.ID, &u.DomainID, &u.Username, &u.Email, &u.PasswordHash,
|
||||
&u.DisplayName, &u.QuotaBytes, &u.UsedBytes, &u.Enabled,
|
||||
&u.Admin, &u.DomainAdmin,
|
||||
&mfaEnc, &u.MFAEnabled, &rcEnc,
|
||||
&u.CreatedAt, &lastLogin,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u.MFASecretEnc = mfaEnc
|
||||
u.RecoveryCodesEnc = rcEnc
|
||||
if lastLogin.Valid {
|
||||
u.LastLogin = lastLogin.Time
|
||||
}
|
||||
users = append(users, &u)
|
||||
}
|
||||
return users, rows.Err()
|
||||
}
|
||||
|
||||
// ---- private ----
|
||||
|
||||
func scanUser(row *sql.Row) (*models.User, error) {
|
||||
var u models.User
|
||||
var mfaEnc, rcEnc []byte
|
||||
var lastLogin sql.NullTime
|
||||
|
||||
err := row.Scan(
|
||||
&u.ID, &u.DomainID, &u.Username, &u.Email, &u.PasswordHash,
|
||||
&u.DisplayName, &u.QuotaBytes, &u.UsedBytes, &u.Enabled,
|
||||
&u.Admin, &u.DomainAdmin,
|
||||
&mfaEnc, &u.MFAEnabled, &rcEnc,
|
||||
&u.CreatedAt, &lastLogin,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan user: %w", err)
|
||||
}
|
||||
u.MFASecretEnc = mfaEnc
|
||||
u.RecoveryCodesEnc = rcEnc
|
||||
if lastLogin.Valid {
|
||||
u.LastLogin = lastLogin.Time
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
// Package delivery implements outbound SMTP delivery: MX lookup, connection,
|
||||
// TLS upgrade, message submission. Used by the queue worker.
|
||||
package delivery
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
deliveryTimeout = 60 * time.Second
|
||||
connectTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// Result holds the outcome of a single delivery attempt.
|
||||
type Result struct {
|
||||
MXHost string
|
||||
SMTPCode int
|
||||
Message string
|
||||
Perm bool // true = permanent failure (5xx), don't retry
|
||||
}
|
||||
|
||||
// Deliver attempts to deliver raw to the address to using a fresh SMTP
|
||||
// connection to the recipient domain's MX. Signs with the given EHLO hostname.
|
||||
// Returns a Result describing success or failure.
|
||||
func Deliver(ctx context.Context, ehloHostname, from, to string, raw []byte) *Result {
|
||||
at := strings.LastIndex(to, "@")
|
||||
if at < 0 {
|
||||
return &Result{Perm: true, Message: "invalid recipient address: " + to}
|
||||
}
|
||||
toDomain := strings.ToLower(to[at+1:])
|
||||
|
||||
mxHosts, err := lookupMX(ctx, toDomain)
|
||||
if err != nil {
|
||||
return &Result{Message: fmt.Sprintf("MX lookup %s: %v", toDomain, err)}
|
||||
}
|
||||
if len(mxHosts) == 0 {
|
||||
return &Result{Perm: true, Message: "no MX records for " + toDomain}
|
||||
}
|
||||
|
||||
var lastResult *Result
|
||||
for _, host := range mxHosts {
|
||||
r := deliver(ctx, ehloHostname, host, from, to, raw)
|
||||
lastResult = r
|
||||
if r.SMTPCode == 0 || r.SMTPCode/100 == 4 {
|
||||
// Temp error or connection failure — try next MX.
|
||||
continue
|
||||
}
|
||||
// 2xx = success, 5xx = permanent failure — stop trying.
|
||||
return r
|
||||
}
|
||||
return lastResult
|
||||
}
|
||||
|
||||
// deliver connects to one MX host and submits the message.
|
||||
func deliver(ctx context.Context, ehloHostname, mxHost, from, to string, raw []byte) *Result {
|
||||
addr := net.JoinHostPort(mxHost, "25")
|
||||
|
||||
dialCtx, cancel := context.WithTimeout(ctx, connectTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := (&net.Dialer{}).DialContext(dialCtx, "tcp", addr)
|
||||
if err != nil {
|
||||
return &Result{MXHost: mxHost, Message: fmt.Sprintf("connect %s: %v", addr, err)}
|
||||
}
|
||||
|
||||
// Wrap in a deadline for the full SMTP exchange.
|
||||
deadline := time.Now().Add(deliveryTimeout)
|
||||
_ = conn.SetDeadline(deadline)
|
||||
|
||||
c, err := smtp.NewClient(conn, mxHost)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return &Result{MXHost: mxHost, Message: fmt.Sprintf("smtp client %s: %v", mxHost, err)}
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// EHLO.
|
||||
if err := c.Hello(ehloHostname); err != nil {
|
||||
return &Result{MXHost: mxHost, Message: fmt.Sprintf("EHLO: %v", err)}
|
||||
}
|
||||
|
||||
// Try STARTTLS (best effort — not all remote servers require it).
|
||||
if ok, _ := c.Extension("STARTTLS"); ok {
|
||||
tlsCfg := &tls.Config{
|
||||
ServerName: mxHost,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
if err := c.StartTLS(tlsCfg); err != nil {
|
||||
log.Printf("[delivery] STARTTLS %s failed (continuing plain): %v", mxHost, err)
|
||||
}
|
||||
}
|
||||
|
||||
// MAIL FROM.
|
||||
if err := c.Mail(from); err != nil {
|
||||
return smtpResult(mxHost, err)
|
||||
}
|
||||
|
||||
// RCPT TO.
|
||||
if err := c.Rcpt(to); err != nil {
|
||||
return smtpResult(mxHost, err)
|
||||
}
|
||||
|
||||
// DATA.
|
||||
wc, err := c.Data()
|
||||
if err != nil {
|
||||
return smtpResult(mxHost, err)
|
||||
}
|
||||
if _, err := wc.Write(raw); err != nil {
|
||||
wc.Close()
|
||||
return &Result{MXHost: mxHost, Message: fmt.Sprintf("write data: %v", err)}
|
||||
}
|
||||
if err := wc.Close(); err != nil {
|
||||
return smtpResult(mxHost, err)
|
||||
}
|
||||
|
||||
_ = c.Quit()
|
||||
|
||||
log.Printf("[delivery] delivered %s → %s via %s", from, to, mxHost)
|
||||
return &Result{MXHost: mxHost, SMTPCode: 250, Message: "2.0.0 OK"}
|
||||
}
|
||||
|
||||
// lookupMX resolves MX records and returns hosts sorted by priority.
|
||||
func lookupMX(ctx context.Context, domain string) ([]string, error) {
|
||||
r := net.DefaultResolver
|
||||
mxs, err := r.LookupMX(ctx, domain)
|
||||
if err != nil {
|
||||
// Treat NXDOMAIN as no-MX (not a transient error).
|
||||
if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound {
|
||||
// Fall back: try A record (some small domains don't publish MX).
|
||||
addrs, aerr := r.LookupHost(ctx, domain)
|
||||
if aerr != nil || len(addrs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return []string{domain}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sort by priority ascending.
|
||||
sort.Slice(mxs, func(i, j int) bool {
|
||||
return mxs[i].Pref < mxs[j].Pref
|
||||
})
|
||||
|
||||
hosts := make([]string, 0, len(mxs))
|
||||
for _, mx := range mxs {
|
||||
h := strings.TrimSuffix(mx.Host, ".")
|
||||
if h != "" {
|
||||
hosts = append(hosts, h)
|
||||
}
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
// smtpResult maps an smtp error to a Result, marking 5xx as permanent.
|
||||
func smtpResult(mxHost string, err error) *Result {
|
||||
if err == nil {
|
||||
return &Result{MXHost: mxHost, SMTPCode: 250, Message: "2.0.0 OK"}
|
||||
}
|
||||
msg := err.Error()
|
||||
code := parseCode(msg)
|
||||
return &Result{
|
||||
MXHost: mxHost,
|
||||
SMTPCode: code,
|
||||
Message: msg,
|
||||
Perm: code/100 == 5,
|
||||
}
|
||||
}
|
||||
|
||||
// parseCode extracts the leading 3-digit SMTP code from an error string.
|
||||
func parseCode(s string) int {
|
||||
if len(s) < 3 {
|
||||
return 0
|
||||
}
|
||||
var code int
|
||||
_, _ = fmt.Sscanf(s[:3], "%d", &code)
|
||||
return code
|
||||
}
|
||||
|
||||
// IsLocal reports whether the given domain is served locally. Used by the
|
||||
// queue worker to skip outbound delivery for internal mail.
|
||||
// The caller supplies its own domain list to avoid a DB call here.
|
||||
func IsLocal(domain string, localDomains []string) bool {
|
||||
domain = strings.ToLower(domain)
|
||||
for _, d := range localDomains {
|
||||
if strings.EqualFold(d, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RecipientDomain returns the domain part of an email address.
|
||||
func RecipientDomain(addr string) string {
|
||||
at := strings.LastIndex(addr, "@")
|
||||
if at < 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(addr[at+1:])
|
||||
}
|
||||
|
||||
// _ suppresses unused import if bytes is only used in tests later.
|
||||
var _ = bytes.NewReader
|
||||
@@ -0,0 +1,465 @@
|
||||
// Package dkim implements DKIM signing (outbound) and verification (inbound)
|
||||
// for RSA-2048 and Ed25519 keys per RFC 6376 and RFC 8463.
|
||||
package dkim
|
||||
|
||||
import (
|
||||
gocrypto "crypto"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Signer holds a loaded private key and signing metadata.
|
||||
type Signer struct {
|
||||
privateKey gocrypto.PrivateKey // *rsa.PrivateKey or ed25519.PrivateKey
|
||||
selector string
|
||||
domain string
|
||||
algo string // "rsa2048" | "ed25519"
|
||||
}
|
||||
|
||||
// GenerateKeyPair generates a DKIM key pair for the given algorithm.
|
||||
// algo must be "rsa2048" or "ed25519".
|
||||
// Returns PEM-encoded private key and PEM-encoded public key.
|
||||
func GenerateKeyPair(algo string) (privateKeyPEM, publicKeyPEM string, err error) {
|
||||
switch algo {
|
||||
case "rsa2048":
|
||||
priv, genErr := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if genErr != nil {
|
||||
return "", "", fmt.Errorf("dkim: generate RSA key: %w", genErr)
|
||||
}
|
||||
privDER := x509.MarshalPKCS1PrivateKey(priv)
|
||||
privBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: privDER}
|
||||
privateKeyPEM = string(pem.EncodeToMemory(privBlock))
|
||||
|
||||
pubDER, marshalErr := x509.MarshalPKIXPublicKey(&priv.PublicKey)
|
||||
if marshalErr != nil {
|
||||
return "", "", fmt.Errorf("dkim: marshal RSA public key: %w", marshalErr)
|
||||
}
|
||||
pubBlock := &pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}
|
||||
publicKeyPEM = string(pem.EncodeToMemory(pubBlock))
|
||||
return privateKeyPEM, publicKeyPEM, nil
|
||||
|
||||
case "ed25519":
|
||||
pub, priv, genErr := ed25519.GenerateKey(rand.Reader)
|
||||
if genErr != nil {
|
||||
return "", "", fmt.Errorf("dkim: generate Ed25519 key: %w", genErr)
|
||||
}
|
||||
privDER, marshalErr := x509.MarshalPKCS8PrivateKey(priv)
|
||||
if marshalErr != nil {
|
||||
return "", "", fmt.Errorf("dkim: marshal Ed25519 private key: %w", marshalErr)
|
||||
}
|
||||
privBlock := &pem.Block{Type: "PRIVATE KEY", Bytes: privDER}
|
||||
privateKeyPEM = string(pem.EncodeToMemory(privBlock))
|
||||
|
||||
pubDER, marshalErr := x509.MarshalPKIXPublicKey(pub)
|
||||
if marshalErr != nil {
|
||||
return "", "", fmt.Errorf("dkim: marshal Ed25519 public key: %w", marshalErr)
|
||||
}
|
||||
pubBlock := &pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}
|
||||
publicKeyPEM = string(pem.EncodeToMemory(pubBlock))
|
||||
return privateKeyPEM, publicKeyPEM, nil
|
||||
|
||||
default:
|
||||
return "", "", fmt.Errorf("dkim: unsupported algorithm %q (want rsa2048 or ed25519)", algo)
|
||||
}
|
||||
}
|
||||
|
||||
// NewSigner parses a PEM private key and returns a Signer.
|
||||
// Tries PKCS1 RSA first, then PKCS8 (Ed25519 or RSA).
|
||||
func NewSigner(privateKeyPEM, domain, selector string) (*Signer, error) {
|
||||
if domain == "" {
|
||||
return nil, fmt.Errorf("dkim: domain must not be empty")
|
||||
}
|
||||
if selector == "" {
|
||||
return nil, fmt.Errorf("dkim: selector must not be empty")
|
||||
}
|
||||
|
||||
block, _ := pem.Decode([]byte(privateKeyPEM))
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("dkim: failed to decode PEM block")
|
||||
}
|
||||
|
||||
// Try PKCS1 RSA.
|
||||
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||
return &Signer{
|
||||
privateKey: rsaKey,
|
||||
selector: selector,
|
||||
domain: domain,
|
||||
algo: "rsa2048",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Try PKCS8 (covers Ed25519 and RSA PKCS8).
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dkim: parse private key: %w", err)
|
||||
}
|
||||
switch k := key.(type) {
|
||||
case ed25519.PrivateKey:
|
||||
return &Signer{
|
||||
privateKey: k,
|
||||
selector: selector,
|
||||
domain: domain,
|
||||
algo: "ed25519",
|
||||
}, nil
|
||||
case *rsa.PrivateKey:
|
||||
return &Signer{
|
||||
privateKey: k,
|
||||
selector: selector,
|
||||
domain: domain,
|
||||
algo: "rsa2048",
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("dkim: unsupported private key type %T", key)
|
||||
}
|
||||
}
|
||||
|
||||
// Sign produces a DKIM-Signature header for the given RFC822 message.
|
||||
// Returns the complete "DKIM-Signature: ..." header line (no trailing CRLF).
|
||||
func (s *Signer) Sign(message []byte) (string, error) {
|
||||
// Split at first blank line (\r\n\r\n or \n\n).
|
||||
headerBytes, bodyBytes := splitMessage(message)
|
||||
|
||||
// Canonicalize body (relaxed).
|
||||
canonBody := canonicalizeBodyRelaxed(bodyBytes)
|
||||
|
||||
// Body hash.
|
||||
bodyHash := sha256.Sum256(canonBody)
|
||||
bh := base64.StdEncoding.EncodeToString(bodyHash[:])
|
||||
|
||||
// Determine algorithm tag.
|
||||
var aTag string
|
||||
switch s.algo {
|
||||
case "ed25519":
|
||||
aTag = "ed25519-sha256"
|
||||
default:
|
||||
aTag = "rsa-sha256"
|
||||
}
|
||||
|
||||
// Signed header fields (lower-case, in sign order).
|
||||
signedFields := "from:to:subject:date:message-id"
|
||||
|
||||
// Build DKIM-Signature with b= empty.
|
||||
ts := fmt.Sprintf("%d", time.Now().Unix())
|
||||
sigHeader := fmt.Sprintf(
|
||||
"DKIM-Signature: v=1; a=%s; c=relaxed/relaxed; d=%s; s=%s; t=%s; bh=%s; h=%s; b=",
|
||||
aTag, s.domain, s.selector, ts, bh, signedFields,
|
||||
)
|
||||
|
||||
// Canonicalize headers to sign + the sig header (b= empty).
|
||||
hdrMap := parseHeaders(headerBytes)
|
||||
var sb strings.Builder
|
||||
for _, field := range strings.Split(signedFields, ":") {
|
||||
field = strings.TrimSpace(field)
|
||||
val, ok := hdrMap[strings.ToLower(field)]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(canonicalizeHeaderRelaxed(field, val))
|
||||
sb.WriteString("\r\n")
|
||||
}
|
||||
// Append the DKIM-Signature line itself (with b= empty), canonicalized.
|
||||
sb.WriteString(canonicalizeHeaderRelaxed("dkim-signature", strings.TrimPrefix(sigHeader, "DKIM-Signature: ")))
|
||||
|
||||
dataToSign := []byte(sb.String())
|
||||
hash := sha256.Sum256(dataToSign)
|
||||
|
||||
var sigBytes []byte
|
||||
var signErr error
|
||||
|
||||
switch s.algo {
|
||||
case "ed25519":
|
||||
privKey, ok := s.privateKey.(ed25519.PrivateKey)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("dkim: private key type mismatch for ed25519")
|
||||
}
|
||||
// RFC 8463: ed25519-sha256 — sign the SHA-256 hash of the data.
|
||||
sigBytes = ed25519.Sign(privKey, hash[:])
|
||||
|
||||
default:
|
||||
privKey, ok := s.privateKey.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("dkim: private key type mismatch for RSA")
|
||||
}
|
||||
sigBytes, signErr = rsa.SignPKCS1v15(rand.Reader, privKey, gocrypto.SHA256, hash[:])
|
||||
if signErr != nil {
|
||||
return "", fmt.Errorf("dkim: RSA sign: %w", signErr)
|
||||
}
|
||||
}
|
||||
|
||||
b := base64.StdEncoding.EncodeToString(sigBytes)
|
||||
return sigHeader + b, nil
|
||||
}
|
||||
|
||||
// Verify finds the DKIM-Signature in message, fetches the DNS public key,
|
||||
// and verifies the signature. Returns the signing domain on success.
|
||||
func Verify(message []byte) (domain string, err error) {
|
||||
headerBytes, bodyBytes := splitMessage(message)
|
||||
hdrMap := parseHeaders(headerBytes)
|
||||
|
||||
sigVal, ok := hdrMap["dkim-signature"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("dkim: no DKIM-Signature header found")
|
||||
}
|
||||
|
||||
params := parseDKIMParams(sigVal)
|
||||
|
||||
sel, ok := params["s"]
|
||||
if !ok || sel == "" {
|
||||
return "", fmt.Errorf("dkim: missing selector (s=) in DKIM-Signature")
|
||||
}
|
||||
dom, ok := params["d"]
|
||||
if !ok || dom == "" {
|
||||
return "", fmt.Errorf("dkim: missing domain (d=) in DKIM-Signature")
|
||||
}
|
||||
bh64, _ := params["bh"]
|
||||
b64, ok := params["b"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("dkim: missing signature (b=) in DKIM-Signature")
|
||||
}
|
||||
signedFields, _ := params["h"]
|
||||
|
||||
// Verify body hash.
|
||||
canonBody := canonicalizeBodyRelaxed(bodyBytes)
|
||||
bodyHash := sha256.Sum256(canonBody)
|
||||
expectedBH := base64.StdEncoding.EncodeToString(bodyHash[:])
|
||||
// Strip whitespace from DNS-retrieved bh for comparison.
|
||||
cleanBH := strings.Map(func(r rune) rune {
|
||||
if r == ' ' || r == '\t' || r == '\r' || r == '\n' {
|
||||
return -1
|
||||
}
|
||||
return r
|
||||
}, bh64)
|
||||
if cleanBH != expectedBH {
|
||||
return "", fmt.Errorf("dkim: body hash mismatch")
|
||||
}
|
||||
|
||||
// DNS lookup.
|
||||
lookupName := sel + "._domainkey." + dom
|
||||
txts, lookupErr := net.LookupTXT(lookupName)
|
||||
if lookupErr != nil {
|
||||
return "", fmt.Errorf("dkim: DNS lookup %s: %w", lookupName, lookupErr)
|
||||
}
|
||||
if len(txts) == 0 {
|
||||
return "", fmt.Errorf("dkim: no TXT record at %s", lookupName)
|
||||
}
|
||||
|
||||
// Join TXT record parts (DNS may split at 255 bytes).
|
||||
dnsRecord := strings.Join(txts, "")
|
||||
dnsParams := parseDKIMParams(dnsRecord)
|
||||
pVal, ok := dnsParams["p"]
|
||||
if !ok || pVal == "" {
|
||||
return "", fmt.Errorf("dkim: no p= (public key) in DNS TXT record")
|
||||
}
|
||||
|
||||
pubDER, decErr := base64.StdEncoding.DecodeString(pVal)
|
||||
if decErr != nil {
|
||||
return "", fmt.Errorf("dkim: decode public key base64: %w", decErr)
|
||||
}
|
||||
|
||||
// Rebuild the data-to-sign exactly as Signer.Sign did.
|
||||
// Canonicalize the signed headers.
|
||||
var sb strings.Builder
|
||||
for _, field := range strings.Split(signedFields, ":") {
|
||||
field = strings.TrimSpace(field)
|
||||
val, ok2 := hdrMap[strings.ToLower(field)]
|
||||
if !ok2 {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(canonicalizeHeaderRelaxed(field, val))
|
||||
sb.WriteString("\r\n")
|
||||
}
|
||||
// Append DKIM-Signature with b= stripped (zeroed), canonicalized.
|
||||
cleanedSig := stripBValue(sigVal)
|
||||
sb.WriteString(canonicalizeHeaderRelaxed("dkim-signature", cleanedSig))
|
||||
|
||||
dataToVerify := []byte(sb.String())
|
||||
hash := sha256.Sum256(dataToVerify)
|
||||
|
||||
sigBytes, decErr := base64.StdEncoding.DecodeString(
|
||||
strings.Map(func(r rune) rune {
|
||||
if r == ' ' || r == '\t' || r == '\r' || r == '\n' {
|
||||
return -1
|
||||
}
|
||||
return r
|
||||
}, b64),
|
||||
)
|
||||
if decErr != nil {
|
||||
return "", fmt.Errorf("dkim: decode signature base64: %w", decErr)
|
||||
}
|
||||
|
||||
// Try RSA PKIX public key first.
|
||||
if pubKey, rsaErr := x509.ParsePKIXPublicKey(pubDER); rsaErr == nil {
|
||||
switch k := pubKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if err := rsa.VerifyPKCS1v15(k, gocrypto.SHA256, hash[:], sigBytes); err != nil {
|
||||
return "", fmt.Errorf("dkim: RSA signature invalid: %w", err)
|
||||
}
|
||||
return dom, nil
|
||||
case ed25519.PublicKey:
|
||||
if !ed25519.Verify(k, hash[:], sigBytes) {
|
||||
return "", fmt.Errorf("dkim: Ed25519 signature invalid")
|
||||
}
|
||||
return dom, nil
|
||||
default:
|
||||
return "", fmt.Errorf("dkim: unsupported public key type %T", pubKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Try raw Ed25519 public key (32 bytes).
|
||||
if len(pubDER) == ed25519.PublicKeySize {
|
||||
edPub := ed25519.PublicKey(pubDER)
|
||||
if !ed25519.Verify(edPub, hash[:], sigBytes) {
|
||||
return "", fmt.Errorf("dkim: Ed25519 signature invalid")
|
||||
}
|
||||
return dom, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("dkim: unable to parse public key from DNS record")
|
||||
}
|
||||
|
||||
// DNSRecord returns the DKIM TXT record string for publishing.
|
||||
func DNSRecord(selector, domain, publicKeyPEM string) string {
|
||||
block, _ := pem.Decode([]byte(publicKeyPEM))
|
||||
var p string
|
||||
if block != nil {
|
||||
p = base64.StdEncoding.EncodeToString(block.Bytes)
|
||||
}
|
||||
return fmt.Sprintf("%s._domainkey.%s. IN TXT \"v=DKIM1; k=rsa; p=%s\"", selector, domain, p)
|
||||
}
|
||||
|
||||
// SPFRecord returns the recommended SPF TXT record for a domain.
|
||||
func SPFRecord(domain string) string {
|
||||
return fmt.Sprintf("%s IN TXT \"v=spf1 mx a ~all\"", domain)
|
||||
}
|
||||
|
||||
// ---- Internals ----
|
||||
|
||||
// splitMessage splits at the first blank line (CRLF or LF variants).
|
||||
func splitMessage(msg []byte) (headers, body []byte) {
|
||||
s := string(msg)
|
||||
for _, sep := range []string{"\r\n\r\n", "\n\n"} {
|
||||
if idx := strings.Index(s, sep); idx >= 0 {
|
||||
return []byte(s[:idx]), []byte(s[idx+len(sep):])
|
||||
}
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// canonicalizeBodyRelaxed applies RFC 6376 relaxed body canonicalization.
|
||||
func canonicalizeBodyRelaxed(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return []byte("\r\n")
|
||||
}
|
||||
// Normalize line endings.
|
||||
s := strings.ReplaceAll(string(body), "\r\n", "\n")
|
||||
s = strings.ReplaceAll(s, "\r", "\n")
|
||||
|
||||
lines := strings.Split(s, "\n")
|
||||
var out []string
|
||||
for _, line := range lines {
|
||||
// Collapse whitespace runs within each line, trim trailing whitespace.
|
||||
fields := strings.Fields(line)
|
||||
out = append(out, strings.Join(fields, " "))
|
||||
}
|
||||
|
||||
// Strip trailing empty lines.
|
||||
for len(out) > 0 && out[len(out)-1] == "" {
|
||||
out = out[:len(out)-1]
|
||||
}
|
||||
|
||||
result := strings.Join(out, "\r\n") + "\r\n"
|
||||
return []byte(result)
|
||||
}
|
||||
|
||||
// canonicalizeHeaderRelaxed applies RFC 6376 relaxed header canonicalization
|
||||
// to a single header name + value. Returns "lowername:value" (no trailing CRLF).
|
||||
func canonicalizeHeaderRelaxed(name, value string) string {
|
||||
name = strings.ToLower(strings.TrimSpace(name))
|
||||
// Collapse all whitespace (including CRLF folding) to a single space.
|
||||
value = strings.ReplaceAll(value, "\r\n", " ")
|
||||
value = strings.ReplaceAll(value, "\r", " ")
|
||||
value = strings.ReplaceAll(value, "\n", " ")
|
||||
// Collapse multiple spaces.
|
||||
parts := strings.Fields(value)
|
||||
value = strings.Join(parts, " ")
|
||||
value = strings.TrimSpace(value)
|
||||
return name + ":" + value
|
||||
}
|
||||
|
||||
// parseHeaders builds a map of lower-cased header name → last value.
|
||||
// Handles multi-line (folded) headers.
|
||||
func parseHeaders(headerBytes []byte) map[string]string {
|
||||
m := make(map[string]string)
|
||||
lines := strings.Split(strings.ReplaceAll(string(headerBytes), "\r\n", "\n"), "\n")
|
||||
|
||||
var curName, curVal string
|
||||
flush := func() {
|
||||
if curName != "" {
|
||||
m[strings.ToLower(curName)] = curVal
|
||||
}
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
if len(line) > 0 && (line[0] == ' ' || line[0] == '\t') {
|
||||
// Folded continuation.
|
||||
curVal += " " + strings.TrimSpace(line)
|
||||
continue
|
||||
}
|
||||
flush()
|
||||
idx := strings.IndexByte(line, ':')
|
||||
if idx < 0 {
|
||||
curName = ""
|
||||
curVal = ""
|
||||
continue
|
||||
}
|
||||
curName = line[:idx]
|
||||
curVal = strings.TrimSpace(line[idx+1:])
|
||||
}
|
||||
flush()
|
||||
return m
|
||||
}
|
||||
|
||||
// parseDKIMParams parses semicolon-separated tag=value pairs (DKIM and DNS TXT).
|
||||
func parseDKIMParams(s string) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, part := range strings.Split(s, ";") {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
idx := strings.IndexByte(part, '=')
|
||||
if idx < 0 {
|
||||
continue
|
||||
}
|
||||
key := strings.TrimSpace(part[:idx])
|
||||
val := strings.TrimSpace(part[idx+1:])
|
||||
m[key] = val
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// stripBValue removes the value of the b= tag (sets it to empty) so the
|
||||
// header can be re-canonicalized for verification.
|
||||
func stripBValue(sigVal string) string {
|
||||
// Find b= and zero everything after it up to the next ;.
|
||||
parts := strings.Split(sigVal, ";")
|
||||
for i, p := range parts {
|
||||
trimmed := strings.TrimSpace(p)
|
||||
if strings.HasPrefix(trimmed, "b=") {
|
||||
parts[i] = " b="
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, ";")
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
// Package dmarc implements basic DMARC (RFC 7489) policy evaluation.
|
||||
// Only stdlib net is used for DNS. Supports p=none, p=quarantine, p=reject.
|
||||
package dmarc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/spf"
|
||||
)
|
||||
|
||||
// Policy represents the DMARC disposition policy.
|
||||
type Policy int
|
||||
|
||||
const (
|
||||
PolicyNone Policy = iota // p=none
|
||||
PolicyQuarantine // p=quarantine
|
||||
PolicyReject // p=reject
|
||||
)
|
||||
|
||||
func (p Policy) String() string {
|
||||
return [...]string{"none", "quarantine", "reject"}[p]
|
||||
}
|
||||
|
||||
// Result holds the DMARC evaluation outcome.
|
||||
type Result struct {
|
||||
Pass bool
|
||||
Policy Policy
|
||||
Disposition string // none | quarantine | reject
|
||||
Reason string
|
||||
}
|
||||
|
||||
// Check evaluates DMARC policy for the given envelope-from domain.
|
||||
// spfPass: whether SPF passed for this envelope-from domain.
|
||||
// dkimDomains: set of signing domains that produced a valid DKIM signature.
|
||||
// Returns a Result and logs no external calls beyond DNS.
|
||||
func Check(fromDomain string, spfResult spf.Result, dkimDomains []string) *Result {
|
||||
record, err := fetchDMARC(fromDomain)
|
||||
if err != nil || record == "" {
|
||||
// No DMARC record — not a failure, but can't enforce.
|
||||
return &Result{
|
||||
Pass: true,
|
||||
Policy: PolicyNone,
|
||||
Disposition: "none",
|
||||
Reason: "no DMARC record for " + fromDomain,
|
||||
}
|
||||
}
|
||||
|
||||
policy := parsePolicy(record)
|
||||
|
||||
// SPF alignment (relaxed): envelope-from domain must equal or be a subdomain
|
||||
// of the DMARC organisational domain.
|
||||
orgDomain := orgDomainOf(fromDomain)
|
||||
spfAligned := spfResult == spf.ResultPass
|
||||
|
||||
// DKIM alignment (relaxed): at least one valid DKIM signature domain must
|
||||
// equal or be a subdomain of the org domain.
|
||||
dkimAligned := false
|
||||
for _, d := range dkimDomains {
|
||||
if strings.EqualFold(d, fromDomain) || strings.HasSuffix(strings.ToLower(d), "."+strings.ToLower(orgDomain)) {
|
||||
dkimAligned = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
pass := spfAligned || dkimAligned
|
||||
reason := fmt.Sprintf("SPF=%v DKIM-aligned=%v policy=%s", spfAligned, dkimAligned, policy)
|
||||
|
||||
if pass {
|
||||
return &Result{
|
||||
Pass: true,
|
||||
Policy: policy,
|
||||
Disposition: "none",
|
||||
Reason: reason,
|
||||
}
|
||||
}
|
||||
|
||||
return &Result{
|
||||
Pass: false,
|
||||
Policy: policy,
|
||||
Disposition: policy.String(),
|
||||
Reason: reason,
|
||||
}
|
||||
}
|
||||
|
||||
// fetchDMARC queries TXT at _dmarc.<domain> and returns the first DMARC record.
|
||||
func fetchDMARC(domain string) (string, error) {
|
||||
txts, err := net.LookupTXT("_dmarc." + domain)
|
||||
if err != nil {
|
||||
if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
for _, txt := range txts {
|
||||
txt = strings.TrimSpace(txt)
|
||||
if strings.HasPrefix(txt, "v=DMARC1") {
|
||||
return txt, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// parsePolicy extracts the p= tag from a DMARC record.
|
||||
func parsePolicy(record string) Policy {
|
||||
for _, part := range strings.Split(record, ";") {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(strings.ToLower(part), "p=") {
|
||||
switch strings.ToLower(part[2:]) {
|
||||
case "quarantine":
|
||||
return PolicyQuarantine
|
||||
case "reject":
|
||||
return PolicyReject
|
||||
}
|
||||
}
|
||||
}
|
||||
return PolicyNone
|
||||
}
|
||||
|
||||
// orgDomainOf returns a simplified "organisational domain" — for this
|
||||
// implementation we use the domain as-is (a full PSL lookup is out of scope).
|
||||
func orgDomainOf(domain string) string {
|
||||
return strings.ToLower(domain)
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
// Package imap provides an IMAP4rev2 server backed by the encrypted SQLite store.
|
||||
package imap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/emersion/go-imap/v2"
|
||||
"github.com/emersion/go-imap/v2/imapserver"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/auth"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/config"
|
||||
appCrypto "ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
)
|
||||
|
||||
// Deps groups dependencies needed by the IMAP server.
|
||||
type Deps struct {
|
||||
DB *db.DB
|
||||
Crypt *appCrypto.Crypto
|
||||
Brute *auth.BruteGuard
|
||||
Cfg *config.Config
|
||||
}
|
||||
|
||||
// NewServer creates an IMAP4rev2 server for port 143 (STARTTLS).
|
||||
func NewServer(d *Deps, tlsCfg *tls.Config) *imapserver.Server {
|
||||
return imapserver.New(&imapserver.Options{
|
||||
NewSession: func(c *imapserver.Conn) (imapserver.Session, *imapserver.GreetingData, error) {
|
||||
clientIP := connRemoteIP(c)
|
||||
if d.Brute != nil && d.Cfg.BruteMaxTries > 0 {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
if banned, err := d.Brute.IsBanned(ctx, clientIP); err == nil && banned {
|
||||
cancel()
|
||||
return nil, nil, fmt.Errorf("connection refused")
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
return &IMAPSession{deps: d, clientIP: clientIP}, &imapserver.GreetingData{}, nil
|
||||
},
|
||||
Caps: imap.CapSet{
|
||||
imap.CapIMAP4rev2: {},
|
||||
imap.CapIMAP4rev1: {},
|
||||
},
|
||||
TLSConfig: tlsCfg,
|
||||
InsecureAuth: tlsCfg == nil,
|
||||
})
|
||||
}
|
||||
|
||||
// ListenAndServe listens on addr (port 143 / STARTTLS).
|
||||
func ListenAndServe(s *imapserver.Server, addr, name string) error {
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s listen %s: %w", name, addr, err)
|
||||
}
|
||||
log.Printf("[%s] listening on %s", name, addr)
|
||||
return s.Serve(ln)
|
||||
}
|
||||
|
||||
// ListenAndServeTLS listens with implicit TLS (port 993 / IMAPS).
|
||||
func ListenAndServeTLS(s *imapserver.Server, addr, name string) error {
|
||||
log.Printf("[%s] listening on %s (TLS)", name, addr)
|
||||
return s.ListenAndServeTLS(addr)
|
||||
}
|
||||
|
||||
// connRemoteIP returns the client IP from an imapserver.Conn.
|
||||
func connRemoteIP(c *imapserver.Conn) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
nc := c.NetConn()
|
||||
if nc == nil {
|
||||
return ""
|
||||
}
|
||||
host, _, _ := net.SplitHostPort(nc.RemoteAddr().String())
|
||||
return host
|
||||
}
|
||||
@@ -0,0 +1,884 @@
|
||||
package imap
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/emersion/go-imap/v2"
|
||||
"github.com/emersion/go-imap/v2/imapserver"
|
||||
"github.com/emersion/go-message/textproto"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
const mailboxDelim rune = '/'
|
||||
|
||||
// msgEntry holds the in-memory descriptor for one message in the selected mailbox.
|
||||
type msgEntry struct {
|
||||
dbID int64
|
||||
uid imap.UID
|
||||
isRead bool
|
||||
isStarred bool
|
||||
isDraft bool
|
||||
isDeleted bool
|
||||
extraFlags string
|
||||
size int64
|
||||
internalDate time.Time
|
||||
}
|
||||
|
||||
func (e *msgEntry) flagList() []imap.Flag {
|
||||
var flags []imap.Flag
|
||||
if e.isRead {
|
||||
flags = append(flags, imap.FlagSeen)
|
||||
}
|
||||
if e.isStarred {
|
||||
flags = append(flags, imap.FlagFlagged)
|
||||
}
|
||||
if e.isDraft {
|
||||
flags = append(flags, imap.FlagDraft)
|
||||
}
|
||||
if e.isDeleted {
|
||||
flags = append(flags, imap.FlagDeleted)
|
||||
}
|
||||
for _, f := range strings.Fields(e.extraFlags) {
|
||||
flags = append(flags, imap.Flag(f))
|
||||
}
|
||||
return flags
|
||||
}
|
||||
|
||||
// IMAPSession implements imapserver.SessionIMAP4rev2.
|
||||
type IMAPSession struct {
|
||||
deps *Deps
|
||||
clientIP string
|
||||
user *models.User // set after Login
|
||||
|
||||
selectedMailbox *models.Mailbox
|
||||
msgs []msgEntry // index+1 = IMAP sequence number
|
||||
mboxTracker *imapserver.MailboxTracker
|
||||
sessionTracker *imapserver.SessionTracker
|
||||
}
|
||||
|
||||
var _ imapserver.SessionIMAP4rev2 = (*IMAPSession)(nil)
|
||||
|
||||
// ---- Authentication ----
|
||||
|
||||
func (s *IMAPSession) Close() error {
|
||||
if s.sessionTracker != nil {
|
||||
s.sessionTracker.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Login(username, password string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
user, err := s.deps.DB.GetUserByEmail(ctx, username)
|
||||
if err != nil {
|
||||
log.Printf("[imap] login error %s: %v", username, err)
|
||||
return imapserver.ErrAuthFailed
|
||||
}
|
||||
if user == nil || !user.Enabled {
|
||||
s.recordAttempt(ctx, username, false)
|
||||
return imapserver.ErrAuthFailed
|
||||
}
|
||||
if err := crypto.CheckPassword(user.PasswordHash, password); err != nil {
|
||||
log.Printf("[imap] auth failed %s from %s", username, s.clientIP)
|
||||
s.recordAttempt(ctx, username, false)
|
||||
return imapserver.ErrAuthFailed
|
||||
}
|
||||
|
||||
s.user = user
|
||||
s.deps.DB.UpdateLastLogin(ctx, user.ID) //nolint:errcheck
|
||||
s.recordAttempt(ctx, username, true)
|
||||
log.Printf("[imap] auth OK %s from %s", username, s.clientIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) recordAttempt(ctx context.Context, email string, success bool) {
|
||||
if s.deps.Brute != nil {
|
||||
s.deps.Brute.RecordAttempt(ctx, s.clientIP, email, success) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Mailbox management ----
|
||||
|
||||
func (s *IMAPSession) Select(mailboxName string, options *imap.SelectOptions) (*imap.SelectData, error) {
|
||||
s.doUnselect()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mbox, err := s.deps.DB.GetMailbox(ctx, s.user.ID, mailboxName)
|
||||
if err != nil {
|
||||
return nil, imapErr(err)
|
||||
}
|
||||
if mbox == nil {
|
||||
return nil, noSuchMailbox()
|
||||
}
|
||||
|
||||
msgs, err := s.deps.DB.ListIMAPMessages(ctx, mbox.ID)
|
||||
if err != nil {
|
||||
return nil, imapErr(err)
|
||||
}
|
||||
|
||||
s.selectedMailbox = mbox
|
||||
s.msgs = make([]msgEntry, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
s.msgs = append(s.msgs, msgEntry{
|
||||
dbID: m.ID,
|
||||
uid: imap.UID(m.UID),
|
||||
isRead: m.IsRead,
|
||||
isStarred: m.IsStarred,
|
||||
isDraft: m.IsDraft,
|
||||
extraFlags: m.Flags,
|
||||
size: m.SizeBytes,
|
||||
internalDate: m.ReceivedAt,
|
||||
})
|
||||
}
|
||||
|
||||
s.mboxTracker = imapserver.NewMailboxTracker(uint32(len(s.msgs)))
|
||||
s.sessionTracker = s.mboxTracker.NewSession()
|
||||
|
||||
var firstUnseen uint32
|
||||
for i, m := range s.msgs {
|
||||
if !m.isRead {
|
||||
firstUnseen = uint32(i + 1)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &imap.SelectData{
|
||||
Flags: allFlags(),
|
||||
PermanentFlags: allFlags(),
|
||||
NumMessages: uint32(len(s.msgs)),
|
||||
FirstUnseenSeqNum: firstUnseen,
|
||||
UIDNext: imap.UID(mbox.UIDNext),
|
||||
UIDValidity: mbox.UIDValidity,
|
||||
List: &imap.ListData{
|
||||
Mailbox: mailboxName,
|
||||
Delim: mailboxDelim,
|
||||
Attrs: mboxAttrs(mbox),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Unselect() error {
|
||||
s.doUnselect()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) doUnselect() {
|
||||
if s.sessionTracker != nil {
|
||||
s.sessionTracker.Close()
|
||||
s.sessionTracker = nil
|
||||
}
|
||||
s.selectedMailbox = nil
|
||||
s.msgs = nil
|
||||
s.mboxTracker = nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Create(mailboxName string, options *imap.CreateOptions) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
existing, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, mailboxName)
|
||||
if existing != nil {
|
||||
return &imap.Error{Type: imap.StatusResponseTypeNo, Code: imap.ResponseCodeAlreadyExists, Text: "Mailbox already exists"}
|
||||
}
|
||||
_, err := s.deps.DB.CreateMailbox(ctx, s.user.ID, mailboxName, "", nil)
|
||||
return imapErr(err)
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Delete(mailboxName string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mbox, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, mailboxName)
|
||||
if mbox == nil {
|
||||
return noSuchMailbox()
|
||||
}
|
||||
if mbox.Type != "" {
|
||||
return &imap.Error{Type: imap.StatusResponseTypeNo, Text: "Cannot delete system mailbox"}
|
||||
}
|
||||
_, err := s.deps.DB.SQL().ExecContext(ctx, "DELETE FROM mailboxes WHERE id=?", mbox.ID)
|
||||
return imapErr(err)
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Rename(oldName, newName string, options *imap.RenameOptions) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mbox, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, oldName)
|
||||
if mbox == nil {
|
||||
return noSuchMailbox()
|
||||
}
|
||||
return imapErr(s.deps.DB.RenameMailbox(ctx, mbox.ID, newName))
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Subscribe(mailboxName string) error {
|
||||
return s.setSubscribed(mailboxName, true)
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Unsubscribe(mailboxName string) error {
|
||||
return s.setSubscribed(mailboxName, false)
|
||||
}
|
||||
|
||||
func (s *IMAPSession) setSubscribed(name string, sub bool) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mbox, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, name)
|
||||
if mbox == nil {
|
||||
return noSuchMailbox()
|
||||
}
|
||||
return imapErr(s.deps.DB.SetMailboxSubscribed(ctx, mbox.ID, sub))
|
||||
}
|
||||
|
||||
func (s *IMAPSession) List(w *imapserver.ListWriter, ref string, patterns []string, options *imap.ListOptions) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if len(patterns) == 0 {
|
||||
return w.WriteList(&imap.ListData{
|
||||
Attrs: []imap.MailboxAttr{imap.MailboxAttrNoSelect},
|
||||
Delim: mailboxDelim,
|
||||
})
|
||||
}
|
||||
|
||||
mailboxes, err := s.deps.DB.ListMailboxes(ctx, s.user.ID)
|
||||
if err != nil {
|
||||
return imapErr(err)
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
name string
|
||||
data imap.ListData
|
||||
}
|
||||
var entries []entry
|
||||
|
||||
for _, mbox := range mailboxes {
|
||||
if options.SelectSubscribed && !mbox.Subscribed {
|
||||
continue
|
||||
}
|
||||
matched := false
|
||||
for _, pat := range patterns {
|
||||
if imapserver.MatchList(mbox.Name, mailboxDelim, ref, pat) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
continue
|
||||
}
|
||||
|
||||
data := imap.ListData{
|
||||
Mailbox: mbox.Name,
|
||||
Delim: mailboxDelim,
|
||||
Attrs: mboxAttrs(mbox),
|
||||
}
|
||||
if mbox.Subscribed {
|
||||
data.Attrs = append(data.Attrs, imap.MailboxAttrSubscribed)
|
||||
}
|
||||
if options.ReturnStatus != nil {
|
||||
sd, _ := s.statusFor(ctx, mbox, options.ReturnStatus)
|
||||
data.Status = sd
|
||||
}
|
||||
entries = append(entries, entry{name: mbox.Name, data: data})
|
||||
}
|
||||
|
||||
sort.Slice(entries, func(i, j int) bool { return entries[i].name < entries[j].name })
|
||||
for _, e := range entries {
|
||||
if err := w.WriteList(&e.data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Status(mailboxName string, options *imap.StatusOptions) (*imap.StatusData, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mbox, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, mailboxName)
|
||||
if mbox == nil {
|
||||
return nil, noSuchMailbox()
|
||||
}
|
||||
return s.statusFor(ctx, mbox, options)
|
||||
}
|
||||
|
||||
func (s *IMAPSession) statusFor(ctx context.Context, mbox *models.Mailbox, opts *imap.StatusOptions) (*imap.StatusData, error) {
|
||||
data := &imap.StatusData{Mailbox: mbox.Name}
|
||||
|
||||
total, unseen, err := s.deps.DB.GetMailboxMessageCounts(ctx, mbox.ID)
|
||||
if err != nil {
|
||||
return nil, imapErr(err)
|
||||
}
|
||||
|
||||
if opts.NumMessages {
|
||||
n := uint32(total)
|
||||
data.NumMessages = &n
|
||||
}
|
||||
if opts.NumUnseen {
|
||||
n := uint32(unseen)
|
||||
data.NumUnseen = &n
|
||||
}
|
||||
if opts.UIDNext {
|
||||
data.UIDNext = imap.UID(mbox.UIDNext)
|
||||
}
|
||||
if opts.UIDValidity {
|
||||
data.UIDValidity = mbox.UIDValidity
|
||||
}
|
||||
if opts.NumRecent {
|
||||
n := uint32(0)
|
||||
data.NumRecent = &n
|
||||
}
|
||||
if opts.NumDeleted {
|
||||
n := uint32(0)
|
||||
data.NumDeleted = &n
|
||||
}
|
||||
if opts.Size {
|
||||
sz, _ := s.deps.DB.GetMailboxSize(ctx, mbox.ID)
|
||||
data.Size = &sz
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Append(mailboxName string, r imap.LiteralReader, options *imap.AppendOptions) (*imap.AppendData, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mbox, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, mailboxName)
|
||||
if mbox == nil {
|
||||
return nil, &imap.Error{Type: imap.StatusResponseTypeNo, Code: imap.ResponseCodeTryCreate, Text: "No such mailbox"}
|
||||
}
|
||||
|
||||
raw, err := readLiteral(r, s.deps.Cfg.MaxMessageSize)
|
||||
if err != nil {
|
||||
return nil, imapErr(err)
|
||||
}
|
||||
|
||||
key, err := s.deps.Crypt.DeriveKey("messages", s.user.ID)
|
||||
if err != nil {
|
||||
return nil, imapErr(fmt.Errorf("derive key: %w", err))
|
||||
}
|
||||
rawEnc, err := crypto.Encrypt(key, raw)
|
||||
if err != nil {
|
||||
return nil, imapErr(err)
|
||||
}
|
||||
|
||||
uid, err := s.deps.DB.NextUID(ctx, mbox.ID)
|
||||
if err != nil {
|
||||
return nil, imapErr(err)
|
||||
}
|
||||
|
||||
internalDate := time.Now().UTC()
|
||||
isRead, isDraft := false, false
|
||||
var extraParts []string
|
||||
|
||||
if options != nil {
|
||||
if !options.Time.IsZero() {
|
||||
internalDate = options.Time
|
||||
}
|
||||
for _, f := range options.Flags {
|
||||
switch f {
|
||||
case imap.FlagSeen:
|
||||
isRead = true
|
||||
case imap.FlagDraft:
|
||||
isDraft = true
|
||||
default:
|
||||
extraParts = append(extraParts, string(f))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ins := &db.MessageInsert{
|
||||
MailboxID: mbox.ID,
|
||||
UID: uid,
|
||||
SizeBytes: int64(len(raw)),
|
||||
RawEnc: rawEnc,
|
||||
IsRead: isRead,
|
||||
IsDraft: isDraft,
|
||||
Flags: strings.Join(extraParts, " "),
|
||||
Date: internalDate,
|
||||
}
|
||||
if _, err := s.deps.DB.InsertMessage(ctx, ins); err != nil {
|
||||
return nil, imapErr(err)
|
||||
}
|
||||
|
||||
return &imap.AppendData{
|
||||
UIDValidity: mbox.UIDValidity,
|
||||
UID: imap.UID(uid),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Poll(w *imapserver.UpdateWriter, allowExpunge bool) error {
|
||||
if s.sessionTracker == nil {
|
||||
return nil
|
||||
}
|
||||
return s.sessionTracker.Poll(w, allowExpunge)
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Idle(w *imapserver.UpdateWriter, stop <-chan struct{}) error {
|
||||
if s.sessionTracker == nil {
|
||||
return nil
|
||||
}
|
||||
return s.sessionTracker.Idle(w, stop)
|
||||
}
|
||||
|
||||
// ---- Selected-state operations ----
|
||||
|
||||
func (s *IMAPSession) Expunge(w *imapserver.ExpungeWriter, uids *imap.UIDSet) error {
|
||||
if s.selectedMailbox == nil {
|
||||
return noSelectedMailbox()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Soft-delete entries marked \Deleted (filtered by UIDs if provided).
|
||||
for i := range s.msgs {
|
||||
m := &s.msgs[i]
|
||||
if uids != nil && !uids.Contains(m.uid) {
|
||||
continue
|
||||
}
|
||||
if m.isDeleted {
|
||||
s.deps.DB.SoftDeleteMessage(ctx, m.dbID) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
deletedUIDs, err := s.deps.DB.HardDeleteMessages(ctx, s.selectedMailbox.ID)
|
||||
if err != nil {
|
||||
return imapErr(err)
|
||||
}
|
||||
deletedSet := make(map[imap.UID]struct{}, len(deletedUIDs))
|
||||
for _, uid := range deletedUIDs {
|
||||
deletedSet[imap.UID(uid)] = struct{}{}
|
||||
}
|
||||
|
||||
// Collect seq nums to expunge (in ascending order, write in reverse).
|
||||
var seqNums []uint32
|
||||
for i, m := range s.msgs {
|
||||
if _, ok := deletedSet[m.uid]; ok {
|
||||
seqNums = append(seqNums, uint32(i+1))
|
||||
}
|
||||
}
|
||||
for i := len(seqNums) - 1; i >= 0; i-- {
|
||||
if err := w.WriteExpunge(seqNums[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Remove from in-memory list.
|
||||
filtered := s.msgs[:0]
|
||||
for _, m := range s.msgs {
|
||||
if _, ok := deletedSet[m.uid]; !ok {
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
}
|
||||
s.msgs = filtered
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Search(kind imapserver.NumKind, criteria *imap.SearchCriteria, options *imap.SearchOptions) (*imap.SearchData, error) {
|
||||
if s.selectedMailbox == nil {
|
||||
return nil, noSelectedMailbox()
|
||||
}
|
||||
|
||||
if kind == imapserver.NumKindUID {
|
||||
var result imap.UIDSet
|
||||
for i, m := range s.msgs {
|
||||
seqNum := uint32(i + 1)
|
||||
if s.matchCriteria(seqNum, &m, criteria) {
|
||||
result.AddNum(m.uid)
|
||||
}
|
||||
}
|
||||
return &imap.SearchData{All: result}, nil
|
||||
}
|
||||
|
||||
var result imap.SeqSet
|
||||
for i, m := range s.msgs {
|
||||
seqNum := uint32(i + 1)
|
||||
if s.matchCriteria(seqNum, &m, criteria) {
|
||||
result.AddNum(seqNum)
|
||||
}
|
||||
}
|
||||
return &imap.SearchData{All: result}, nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) matchCriteria(seqNum uint32, m *msgEntry, c *imap.SearchCriteria) bool {
|
||||
if c == nil {
|
||||
return true
|
||||
}
|
||||
for _, seqSet := range c.SeqNum {
|
||||
if !seqSet.Contains(seqNum) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, uidSet := range c.UID {
|
||||
if !uidSet.Contains(m.uid) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
flagSet := make(map[imap.Flag]struct{})
|
||||
for _, f := range m.flagList() {
|
||||
flagSet[f] = struct{}{}
|
||||
}
|
||||
for _, f := range c.Flag {
|
||||
if _, ok := flagSet[f]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, f := range c.NotFlag {
|
||||
if _, ok := flagSet[f]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if c.Larger != 0 && m.size <= c.Larger {
|
||||
return false
|
||||
}
|
||||
if c.Smaller != 0 && m.size >= c.Smaller {
|
||||
return false
|
||||
}
|
||||
if !matchDate(m.internalDate, c.Since, c.Before) {
|
||||
return false
|
||||
}
|
||||
for _, sub := range c.Not {
|
||||
if s.matchCriteria(seqNum, m, &sub) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, or := range c.Or {
|
||||
if !s.matchCriteria(seqNum, m, &or[0]) && !s.matchCriteria(seqNum, m, &or[1]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Fetch(w *imapserver.FetchWriter, numSet imap.NumSet, options *imap.FetchOptions) error {
|
||||
if s.selectedMailbox == nil {
|
||||
return noSelectedMailbox()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
key, err := s.deps.Crypt.DeriveKey("messages", s.user.ID)
|
||||
if err != nil {
|
||||
return imapErr(err)
|
||||
}
|
||||
|
||||
for i := range s.msgs {
|
||||
m := &s.msgs[i]
|
||||
seqNum := uint32(i + 1)
|
||||
if !numSetContains(numSet, seqNum, m.uid) {
|
||||
continue
|
||||
}
|
||||
|
||||
rw := w.CreateMessage(seqNum)
|
||||
rw.WriteUID(m.uid)
|
||||
|
||||
if options.Flags {
|
||||
rw.WriteFlags(m.flagList())
|
||||
}
|
||||
if options.InternalDate {
|
||||
rw.WriteInternalDate(m.internalDate)
|
||||
}
|
||||
if options.RFC822Size {
|
||||
rw.WriteRFC822Size(m.size)
|
||||
}
|
||||
|
||||
needRaw := options.Envelope || options.BodyStructure != nil ||
|
||||
len(options.BodySection) > 0 || len(options.BinarySection) > 0 ||
|
||||
len(options.BinarySectionSize) > 0
|
||||
|
||||
if needRaw {
|
||||
rawEnc, err := s.deps.DB.GetMessageRaw(ctx, m.dbID)
|
||||
if err != nil || rawEnc == nil {
|
||||
rw.Close() //nolint:errcheck
|
||||
continue
|
||||
}
|
||||
raw, err := crypto.Decrypt(key, rawEnc)
|
||||
if err != nil {
|
||||
log.Printf("[imap] decrypt %d: %v", m.dbID, err)
|
||||
rw.Close() //nolint:errcheck
|
||||
continue
|
||||
}
|
||||
|
||||
if options.Envelope {
|
||||
if env := extractEnvelope(raw); env != nil {
|
||||
rw.WriteEnvelope(env)
|
||||
}
|
||||
}
|
||||
if options.BodyStructure != nil {
|
||||
rw.WriteBodyStructure(imapserver.ExtractBodyStructure(bytes.NewReader(raw)))
|
||||
}
|
||||
for _, bs := range options.BodySection {
|
||||
buf := imapserver.ExtractBodySection(bytes.NewReader(raw), bs)
|
||||
wc := rw.WriteBodySection(bs, int64(len(buf)))
|
||||
wc.Write(buf) //nolint:errcheck
|
||||
wc.Close() //nolint:errcheck
|
||||
}
|
||||
for _, bs := range options.BinarySection {
|
||||
buf := imapserver.ExtractBinarySection(bytes.NewReader(raw), bs)
|
||||
wc := rw.WriteBinarySection(bs, int64(len(buf)))
|
||||
wc.Write(buf) //nolint:errcheck
|
||||
wc.Close() //nolint:errcheck
|
||||
}
|
||||
for _, bss := range options.BinarySectionSize {
|
||||
n := imapserver.ExtractBinarySectionSize(bytes.NewReader(raw), bss)
|
||||
rw.WriteBinarySectionSize(bss, n)
|
||||
}
|
||||
}
|
||||
|
||||
if err := rw.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Store(w *imapserver.FetchWriter, numSet imap.NumSet, flags *imap.StoreFlags, options *imap.StoreOptions) error {
|
||||
if s.selectedMailbox == nil {
|
||||
return noSelectedMailbox()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for i := range s.msgs {
|
||||
m := &s.msgs[i]
|
||||
seqNum := uint32(i + 1)
|
||||
if !numSetContains(numSet, seqNum, m.uid) {
|
||||
continue
|
||||
}
|
||||
applyStoreFlags(m, flags)
|
||||
s.deps.DB.SetMessageFlags(ctx, m.dbID, m.isRead, m.isStarred, m.isDraft, m.extraFlags) //nolint:errcheck
|
||||
if m.isDeleted {
|
||||
s.deps.DB.SoftDeleteMessage(ctx, m.dbID) //nolint:errcheck
|
||||
}
|
||||
|
||||
if !flags.Silent {
|
||||
rw := w.CreateMessage(seqNum)
|
||||
rw.WriteUID(m.uid)
|
||||
rw.WriteFlags(m.flagList())
|
||||
if err := rw.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Copy(numSet imap.NumSet, destName string) (*imap.CopyData, error) {
|
||||
if s.selectedMailbox == nil {
|
||||
return nil, noSelectedMailbox()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dest, _ := s.deps.DB.GetMailbox(ctx, s.user.ID, destName)
|
||||
if dest == nil {
|
||||
return nil, &imap.Error{Type: imap.StatusResponseTypeNo, Code: imap.ResponseCodeTryCreate, Text: "No such mailbox"}
|
||||
}
|
||||
|
||||
var sourceUIDs, destUIDs imap.UIDSet
|
||||
for i, m := range s.msgs {
|
||||
seqNum := uint32(i + 1)
|
||||
if !numSetContains(numSet, seqNum, m.uid) {
|
||||
continue
|
||||
}
|
||||
newUID, err := s.deps.DB.CopyMessageToMailbox(ctx, m.dbID, dest.ID, s.user.ID)
|
||||
if err != nil {
|
||||
log.Printf("[imap] copy msg %d: %v", m.dbID, err)
|
||||
continue
|
||||
}
|
||||
sourceUIDs.AddNum(m.uid)
|
||||
destUIDs.AddNum(imap.UID(newUID))
|
||||
}
|
||||
return &imap.CopyData{
|
||||
UIDValidity: dest.UIDValidity,
|
||||
SourceUIDs: sourceUIDs,
|
||||
DestUIDs: destUIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Move(w *imapserver.MoveWriter, numSet imap.NumSet, destName string) error {
|
||||
copyData, err := s.Copy(numSet, destName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.WriteCopyData(copyData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var seqNums []uint32
|
||||
for i := range s.msgs {
|
||||
m := &s.msgs[i]
|
||||
seqNum := uint32(i + 1)
|
||||
if !numSetContains(numSet, seqNum, m.uid) {
|
||||
continue
|
||||
}
|
||||
s.deps.DB.SoftDeleteMessage(ctx, m.dbID) //nolint:errcheck
|
||||
seqNums = append(seqNums, seqNum)
|
||||
}
|
||||
s.deps.DB.HardDeleteMessages(ctx, s.selectedMailbox.ID) //nolint:errcheck
|
||||
|
||||
for i := len(seqNums) - 1; i >= 0; i-- {
|
||||
if err := w.WriteExpunge(seqNums[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Remove moved messages.
|
||||
kept := s.msgs[:0]
|
||||
for i, m := range s.msgs {
|
||||
seqNum := uint32(i + 1)
|
||||
if !numSetContains(numSet, seqNum, m.uid) {
|
||||
kept = append(kept, m)
|
||||
}
|
||||
}
|
||||
s.msgs = kept
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *IMAPSession) Namespace() (*imap.NamespaceData, error) {
|
||||
return &imap.NamespaceData{
|
||||
Personal: []imap.NamespaceDescriptor{{Delim: mailboxDelim}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ---- Helpers ----
|
||||
|
||||
// numSetContains checks seqNum (for SeqSet) or uid (for UIDSet).
|
||||
func numSetContains(numSet imap.NumSet, seqNum uint32, uid imap.UID) bool {
|
||||
switch ns := numSet.(type) {
|
||||
case imap.SeqSet:
|
||||
return ns.Contains(seqNum)
|
||||
case imap.UIDSet:
|
||||
return ns.Contains(uid)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func applyStoreFlags(m *msgEntry, store *imap.StoreFlags) {
|
||||
flagMap := map[imap.Flag]*bool{
|
||||
imap.FlagSeen: &m.isRead,
|
||||
imap.FlagFlagged: &m.isStarred,
|
||||
imap.FlagDraft: &m.isDraft,
|
||||
imap.FlagDeleted: &m.isDeleted,
|
||||
}
|
||||
switch store.Op {
|
||||
case imap.StoreFlagsSet:
|
||||
m.isRead, m.isStarred, m.isDraft, m.isDeleted = false, false, false, false
|
||||
m.extraFlags = ""
|
||||
for _, f := range store.Flags {
|
||||
if ptr, ok := flagMap[f]; ok {
|
||||
*ptr = true
|
||||
} else {
|
||||
m.extraFlags = strings.TrimSpace(m.extraFlags + " " + string(f))
|
||||
}
|
||||
}
|
||||
case imap.StoreFlagsAdd:
|
||||
for _, f := range store.Flags {
|
||||
if ptr, ok := flagMap[f]; ok {
|
||||
*ptr = true
|
||||
} else if !strings.Contains(m.extraFlags, string(f)) {
|
||||
m.extraFlags = strings.TrimSpace(m.extraFlags + " " + string(f))
|
||||
}
|
||||
}
|
||||
case imap.StoreFlagsDel:
|
||||
for _, f := range store.Flags {
|
||||
if ptr, ok := flagMap[f]; ok {
|
||||
*ptr = false
|
||||
} else {
|
||||
m.extraFlags = strings.TrimSpace(strings.ReplaceAll(m.extraFlags, string(f), ""))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func allFlags() []imap.Flag {
|
||||
return []imap.Flag{
|
||||
imap.FlagSeen,
|
||||
imap.FlagAnswered,
|
||||
imap.FlagFlagged,
|
||||
imap.FlagDeleted,
|
||||
imap.FlagDraft,
|
||||
}
|
||||
}
|
||||
|
||||
func mboxAttrs(mbox *models.Mailbox) []imap.MailboxAttr {
|
||||
su := db.MailboxTypeToSpecialUse(mbox.Type)
|
||||
if su != "" {
|
||||
return []imap.MailboxAttr{imap.MailboxAttr(su)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractEnvelope(raw []byte) *imap.Envelope {
|
||||
br := bufio.NewReader(bytes.NewReader(raw))
|
||||
header, err := textproto.ReadHeader(br)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return imapserver.ExtractEnvelope(header)
|
||||
}
|
||||
|
||||
func matchDate(t, since, before time.Time) bool {
|
||||
t = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
if !since.IsZero() && t.Before(since) {
|
||||
return false
|
||||
}
|
||||
if !before.IsZero() && !t.Before(before) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func readLiteral(r imap.LiteralReader, maxSize int64) ([]byte, error) {
|
||||
buf := make([]byte, 0, 4096)
|
||||
tmp := make([]byte, 32768)
|
||||
var total int64
|
||||
for {
|
||||
n, err := r.Read(tmp)
|
||||
if n > 0 {
|
||||
total += int64(n)
|
||||
if total > maxSize {
|
||||
return nil, fmt.Errorf("message too large")
|
||||
}
|
||||
buf = append(buf, tmp[:n]...)
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func imapErr(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if _, ok := err.(*imap.Error); ok {
|
||||
return err
|
||||
}
|
||||
return &imap.Error{Type: imap.StatusResponseTypeNo, Text: "server error"}
|
||||
}
|
||||
|
||||
func noSuchMailbox() error {
|
||||
return &imap.Error{Type: imap.StatusResponseTypeNo, Code: imap.ResponseCodeNonExistent, Text: "No such mailbox"}
|
||||
}
|
||||
|
||||
func noSelectedMailbox() error {
|
||||
return &imap.Error{Type: imap.StatusResponseTypeBad, Text: "no mailbox selected"}
|
||||
}
|
||||
@@ -0,0 +1,295 @@
|
||||
// Package models defines all shared data structures.
|
||||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// ---- Domain & User ----
|
||||
|
||||
type Domain struct {
|
||||
ID int64
|
||||
Name string
|
||||
Enabled bool
|
||||
DKIMPrivateEnc []byte // AES-256-GCM encrypted PEM private key
|
||||
DKIMPublic string // PEM public key (not secret)
|
||||
DKIMSelector string
|
||||
DKIMAlgo string // rsa2048 | ed25519
|
||||
SPFPolicy string // stored for reference
|
||||
DMARCPolicy string // stored for reference
|
||||
MaxUsers int
|
||||
MaxQuotaBytes int64
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64
|
||||
DomainID int64
|
||||
Username string // local part (before @)
|
||||
Email string // full address
|
||||
PasswordHash string // bcrypt
|
||||
DisplayName string
|
||||
QuotaBytes int64
|
||||
UsedBytes int64
|
||||
Enabled bool
|
||||
Admin bool // global admin
|
||||
DomainAdmin bool // admin of own domain only
|
||||
MFASecretEnc []byte // encrypted TOTP secret; nil = MFA disabled
|
||||
MFAEnabled bool
|
||||
RecoveryCodesEnc []byte // encrypted JSON array of one-time codes
|
||||
CreatedAt time.Time
|
||||
LastLogin time.Time
|
||||
}
|
||||
|
||||
type UserAlias struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
AliasEmail string
|
||||
}
|
||||
|
||||
// ---- Mail storage ----
|
||||
|
||||
// MailboxType canonical values.
|
||||
const (
|
||||
MailboxInbox = "inbox"
|
||||
MailboxSent = "sent"
|
||||
MailboxDrafts = "drafts"
|
||||
MailboxTrash = "trash"
|
||||
MailboxSpam = "spam"
|
||||
MailboxArchive = "archive"
|
||||
MailboxCustom = "custom"
|
||||
)
|
||||
|
||||
type Mailbox struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Name string // IMAP mailbox name (e.g. "INBOX", "Sent", "Folder/Sub")
|
||||
Type string // MailboxInbox … MailboxCustom
|
||||
ParentID *int64
|
||||
UIDValidity uint32
|
||||
UIDNext uint32
|
||||
Subscribed bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID int64
|
||||
MailboxID int64
|
||||
UID uint32
|
||||
MessageID string // RFC 2822 Message-ID header (not encrypted — needed for threading)
|
||||
Subject string // plaintext for list/search (consider sensitivity vs usability)
|
||||
FromEmail string
|
||||
FromName string
|
||||
ToList string // comma-separated
|
||||
CCList string
|
||||
BCCList string
|
||||
ReplyTo string
|
||||
Date time.Time
|
||||
BodyTextEnc []byte // AES-256-GCM encrypted text/plain
|
||||
BodyHTMLEnc []byte // AES-256-GCM encrypted text/html
|
||||
RawEnc []byte // AES-256-GCM encrypted full RFC822 message
|
||||
SizeBytes int64
|
||||
HasAttachment bool
|
||||
IsRead bool
|
||||
IsStarred bool
|
||||
IsDraft bool
|
||||
Flags string // raw IMAP flags string
|
||||
SpamScore int
|
||||
ReceivedAt time.Time
|
||||
DeletedAt *time.Time // soft delete; nil = not deleted
|
||||
}
|
||||
|
||||
type Attachment struct {
|
||||
ID int64
|
||||
MessageID int64
|
||||
Filename string
|
||||
ContentType string
|
||||
SizeBytes int64
|
||||
DataEnc []byte // AES-256-GCM encrypted bytes (or nil if fs-backed)
|
||||
DataPath string // filesystem path (if STORAGE_BACKEND=fs)
|
||||
ContentID string // MIME Content-ID for inline images
|
||||
Inline bool
|
||||
MIMEPath string // dot-separated MIME section path e.g. "1.2"
|
||||
}
|
||||
|
||||
// ---- Delivery queue ----
|
||||
|
||||
const (
|
||||
QueuePending = "pending"
|
||||
QueueSending = "sending"
|
||||
QueueDelivered = "delivered"
|
||||
QueueFailed = "failed"
|
||||
QueueBounced = "bounced"
|
||||
)
|
||||
|
||||
type QueueEntry struct {
|
||||
ID int64
|
||||
DomainID int64
|
||||
FromAddr string
|
||||
ToAddr string
|
||||
RawEnc []byte // encrypted RFC822
|
||||
MessageID string
|
||||
Status string // QueuePending … QueueBounced
|
||||
Attempts int
|
||||
LastAttempt *time.Time
|
||||
NextAttempt time.Time
|
||||
ErrorLog string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type DeliveryLog struct {
|
||||
ID int64
|
||||
QueueID int64
|
||||
FromAddr string
|
||||
ToAddr string
|
||||
Status string
|
||||
SMTPCode int
|
||||
SMTPMessage string
|
||||
MXHost string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// ---- Sessions & Security ----
|
||||
|
||||
type Session struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
TokenHash string // SHA-256 hex of the raw bearer token
|
||||
IP string
|
||||
UserAgent string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type IPBan struct {
|
||||
ID int64
|
||||
IP string
|
||||
Reason string
|
||||
BannedAt time.Time
|
||||
ExpiresAt *time.Time // nil = permanent
|
||||
ReleasedBy string
|
||||
}
|
||||
|
||||
type LoginAttempt struct {
|
||||
ID int64
|
||||
IP string
|
||||
UserEmail string
|
||||
Success bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type SecurityEvent struct {
|
||||
ID int64
|
||||
Type string // brute_ban | auth_fail | relay_attempt | etc.
|
||||
IP string
|
||||
UserID *int64
|
||||
Detail string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// ---- External accounts (Gmail / Outlook / custom IMAP) ----
|
||||
|
||||
const (
|
||||
ProviderGmail = "gmail"
|
||||
ProviderOutlook = "outlook"
|
||||
ProviderCustom = "custom"
|
||||
)
|
||||
|
||||
type ExternalAccount struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Provider string // ProviderGmail | ProviderOutlook | ProviderCustom
|
||||
EmailAddress string
|
||||
DisplayName string
|
||||
AccessTokenEnc []byte // encrypted OAuth2 or password
|
||||
RefreshTokenEnc []byte // encrypted OAuth2 refresh token
|
||||
TokenExpiry time.Time
|
||||
IMAPHost string
|
||||
IMAPPort int
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
Enabled bool
|
||||
SyncEnabled bool
|
||||
LastSync time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// ---- Contacts (CardDAV) ----
|
||||
|
||||
type AddressBook struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Name string
|
||||
Description string
|
||||
Color string
|
||||
SyncToken int64
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Contact struct {
|
||||
ID int64
|
||||
AddressBookID int64
|
||||
UID string
|
||||
VCardEnc []byte // AES-256-GCM encrypted vCard 3.0/4.0
|
||||
ETag string // hex(sha256(raw vcard)) — for sync
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// ---- Calendar (CalDAV) ----
|
||||
|
||||
type Calendar struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Name string
|
||||
Description string
|
||||
Color string
|
||||
Timezone string
|
||||
SyncToken int64
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type CalendarEvent struct {
|
||||
ID int64
|
||||
CalendarID int64
|
||||
UID string
|
||||
ICalEnc []byte // AES-256-GCM encrypted iCalendar data
|
||||
ETag string // hex(sha256(raw ical))
|
||||
DTStart time.Time
|
||||
DTEnd time.Time
|
||||
Summary string // plaintext for calendar grid display
|
||||
Recurring bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// ---- Spam / Bayesian ----
|
||||
|
||||
type SpamToken struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Token string
|
||||
SpamCount int64
|
||||
HamCount int64
|
||||
}
|
||||
|
||||
// ---- Compose helpers (not persisted directly) ----
|
||||
|
||||
type Attachment_Upload struct {
|
||||
Filename string
|
||||
ContentType string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
type ComposeRequest struct {
|
||||
AccountID int64 // 0 = local account, >0 = external account ID
|
||||
FromEmail string
|
||||
To []string
|
||||
CC []string
|
||||
BCC []string
|
||||
Subject string
|
||||
BodyText string
|
||||
BodyHTML string
|
||||
Attachments []Attachment_Upload
|
||||
InReplyTo string
|
||||
References string
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gosmtp "github.com/emersion/go-smtp"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/dkim"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/spam"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/spf"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/storage"
|
||||
)
|
||||
|
||||
// InboundBackend implements gosmtp.Backend for port 25 (receive from internet).
|
||||
type InboundBackend struct {
|
||||
deps *Deps
|
||||
}
|
||||
|
||||
func (b *InboundBackend) NewSession(c *gosmtp.Conn) (gosmtp.Session, error) {
|
||||
clientIP, _, _ := net.SplitHostPort(c.Conn().RemoteAddr().String())
|
||||
|
||||
// Check IP ban.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var banned bool
|
||||
var err error
|
||||
if b.deps.Cfg.BruteMaxTries > 0 && b.deps.Brute != nil {
|
||||
banned, err = b.deps.Brute.IsBanned(ctx, clientIP)
|
||||
if err != nil {
|
||||
log.Printf("[smtp/inbound] ban check error %s: %v", clientIP, err)
|
||||
}
|
||||
}
|
||||
if banned {
|
||||
return nil, fmt.Errorf("connection refused")
|
||||
}
|
||||
|
||||
return &InboundSession{
|
||||
deps: b.deps,
|
||||
clientIP: net.ParseIP(clientIP),
|
||||
log: func(f string, a ...any) { log.Printf("[smtp/inbound] "+f, a...) },
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InboundSession handles one inbound SMTP connection.
|
||||
type InboundSession struct {
|
||||
deps *Deps
|
||||
clientIP net.IP
|
||||
from string
|
||||
fromDomain string
|
||||
rcpts []string // validated local recipients (email addresses)
|
||||
rcptUsers []int64 // corresponding user IDs
|
||||
spfResult spf.Result
|
||||
log func(string, ...any)
|
||||
}
|
||||
|
||||
// AuthPlain is not used on port 25; always return error.
|
||||
func (s *InboundSession) AuthPlain(username, password string) error {
|
||||
return fmt.Errorf("AUTH not supported on port 25")
|
||||
}
|
||||
|
||||
func (s *InboundSession) Mail(from string, opts *gosmtp.MailOptions) error {
|
||||
if from == "" {
|
||||
// Bounce messages use empty envelope sender — allow.
|
||||
s.from = ""
|
||||
s.fromDomain = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
addr, err := mail.ParseAddress(from)
|
||||
if err != nil {
|
||||
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 7}, Message: "invalid sender address"}
|
||||
}
|
||||
|
||||
s.from = addr.Address
|
||||
at := strings.LastIndex(s.from, "@")
|
||||
if at < 0 {
|
||||
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 7}, Message: "invalid sender domain"}
|
||||
}
|
||||
s.fromDomain = strings.ToLower(s.from[at+1:])
|
||||
|
||||
// SPF check (async, best-effort).
|
||||
if s.deps.Cfg.SpamCheckSPF && s.clientIP != nil && s.fromDomain != "" {
|
||||
s.spfResult, _ = spf.Check(s.clientIP, s.fromDomain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InboundSession) Rcpt(to string, opts *gosmtp.RcptOptions) error {
|
||||
addr, err := mail.ParseAddress(to)
|
||||
if err != nil {
|
||||
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 3}, Message: "invalid recipient"}
|
||||
}
|
||||
email := strings.ToLower(addr.Address)
|
||||
|
||||
at := strings.LastIndex(email, "@")
|
||||
if at < 0 {
|
||||
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 3}, Message: "invalid recipient"}
|
||||
}
|
||||
domain := email[at+1:]
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Must be a local domain.
|
||||
local, err := s.deps.DB.IsLocalDomain(ctx, domain)
|
||||
if err != nil || !local {
|
||||
return &gosmtp.SMTPError{Code: 550, EnhancedCode: gosmtp.EnhancedCode{5, 1, 2}, Message: "relay access denied"}
|
||||
}
|
||||
|
||||
// Resolve to a user (handles aliases).
|
||||
user, err := s.deps.DB.ResolveEmail(ctx, email)
|
||||
if err != nil {
|
||||
return &gosmtp.SMTPError{Code: 451, EnhancedCode: gosmtp.EnhancedCode{4, 3, 0}, Message: "temporary lookup error"}
|
||||
}
|
||||
if user == nil || !user.Enabled {
|
||||
return &gosmtp.SMTPError{Code: 550, EnhancedCode: gosmtp.EnhancedCode{5, 1, 1}, Message: "user unknown"}
|
||||
}
|
||||
|
||||
s.rcpts = append(s.rcpts, email)
|
||||
s.rcptUsers = append(s.rcptUsers, user.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InboundSession) Data(r io.Reader) error {
|
||||
if len(s.rcpts) == 0 {
|
||||
return &gosmtp.SMTPError{Code: 503, Message: "no recipients"}
|
||||
}
|
||||
|
||||
// Read message with size cap (already enforced by go-smtp, but be safe).
|
||||
raw, err := io.ReadAll(io.LimitReader(r, s.deps.Cfg.MaxMessageSize+1))
|
||||
if err != nil {
|
||||
return fmt.Errorf("read data: %w", err)
|
||||
}
|
||||
if int64(len(raw)) > s.deps.Cfg.MaxMessageSize {
|
||||
return &gosmtp.SMTPError{Code: 552, EnhancedCode: gosmtp.EnhancedCode{5, 3, 4}, Message: "message too large"}
|
||||
}
|
||||
|
||||
// Parse headers for metadata.
|
||||
msg, err := mail.ReadMessage(strings.NewReader(string(raw)))
|
||||
if err != nil {
|
||||
return &gosmtp.SMTPError{Code: 550, EnhancedCode: gosmtp.EnhancedCode{5, 6, 0}, Message: "malformed message"}
|
||||
}
|
||||
|
||||
subject := decodeHeader(msg.Header.Get("Subject"))
|
||||
fromHeader := msg.Header.Get("From")
|
||||
fromAddr, fromName := parseFromHeader(fromHeader)
|
||||
if fromAddr == "" && s.from != "" {
|
||||
fromAddr = s.from
|
||||
}
|
||||
msgID := msg.Header.Get("Message-ID")
|
||||
dateStr := msg.Header.Get("Date")
|
||||
msgDate, _ := mail.ParseDate(dateStr)
|
||||
if msgDate.IsZero() {
|
||||
msgDate = time.Now().UTC()
|
||||
}
|
||||
|
||||
// DKIM verification.
|
||||
dkimValid := false
|
||||
dkimPresent := strings.Contains(string(raw), "DKIM-Signature:")
|
||||
if dkimPresent {
|
||||
_, dkimErr := dkim.Verify(raw)
|
||||
dkimValid = dkimErr == nil
|
||||
}
|
||||
|
||||
// Spam scoring (per recipient — each has their own Bayesian model).
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Deliver to each local recipient.
|
||||
for i, userID := range s.rcptUsers {
|
||||
rcptEmail := s.rcpts[i]
|
||||
|
||||
// Build spam score params.
|
||||
params := &spam.Params{
|
||||
ClientIP: s.clientIP,
|
||||
SenderDomain: s.fromDomain,
|
||||
SPFResult: s.spfResult,
|
||||
DKIMValid: dkimValid,
|
||||
DKIMPresent: dkimPresent,
|
||||
Subject: subject,
|
||||
FromHeader: fromHeader,
|
||||
RecipCount: len(s.rcpts),
|
||||
HasDateHeader: dateStr != "",
|
||||
HasMsgIDHeader: msgID != "",
|
||||
}
|
||||
|
||||
spamResult := s.deps.Scorer.Score(ctx, userID, params)
|
||||
|
||||
// Choose target mailbox.
|
||||
mboxType := models.MailboxInbox
|
||||
if spamResult.IsSpam {
|
||||
mboxType = models.MailboxSpam
|
||||
s.log("message for %s scored %d (spam), delivering to Spam", rcptEmail, spamResult.Total)
|
||||
}
|
||||
|
||||
mbox, err := s.deps.DB.GetMailboxByType(ctx, userID, mboxType)
|
||||
if err != nil || mbox == nil {
|
||||
// Fallback to INBOX.
|
||||
mbox, err = s.deps.DB.GetMailboxByType(ctx, userID, models.MailboxInbox)
|
||||
if err != nil || mbox == nil {
|
||||
s.log("no inbox for user %d (%s): %v", userID, rcptEmail, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
incoming := &storage.IncomingMessage{
|
||||
Raw: raw,
|
||||
FromEmail: fromAddr,
|
||||
FromName: fromName,
|
||||
ToList: strings.Join(s.rcpts, ", "),
|
||||
Subject: subject,
|
||||
Date: msgDate,
|
||||
MessageID: msgID,
|
||||
SpamScore: spamResult.Total,
|
||||
}
|
||||
|
||||
_, err = s.deps.Store.SaveIncoming(ctx, userID, mbox.ID, incoming)
|
||||
if err != nil {
|
||||
s.log("store message for %s: %v", rcptEmail, err)
|
||||
return &gosmtp.SMTPError{Code: 451, EnhancedCode: gosmtp.EnhancedCode{4, 3, 0}, Message: "storage error"}
|
||||
}
|
||||
|
||||
s.log("delivered to %s (spam=%v score=%d)", rcptEmail, spamResult.IsSpam, spamResult.Total)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InboundSession) Reset() {
|
||||
s.from = ""
|
||||
s.fromDomain = ""
|
||||
s.rcpts = s.rcpts[:0]
|
||||
s.rcptUsers = s.rcptUsers[:0]
|
||||
s.spfResult = spf.ResultNone
|
||||
}
|
||||
|
||||
func (s *InboundSession) Logout() error { return nil }
|
||||
|
||||
// ---- Helpers ----
|
||||
|
||||
func parseFromHeader(h string) (addr, name string) {
|
||||
if h == "" {
|
||||
return "", ""
|
||||
}
|
||||
a, err := mail.ParseAddress(h)
|
||||
if err != nil {
|
||||
return h, ""
|
||||
}
|
||||
return a.Address, a.Name
|
||||
}
|
||||
|
||||
func decodeHeader(h string) string {
|
||||
// mime.WordDecoder would be imported from mime package. Keep it simple.
|
||||
return h
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/delivery"
|
||||
)
|
||||
|
||||
// QueueWorker polls the delivery queue and dispatches messages.
|
||||
type QueueWorker struct {
|
||||
deps *Deps
|
||||
interval time.Duration // how often to poll (default 30s)
|
||||
}
|
||||
|
||||
// NewQueueWorker creates a QueueWorker backed by the given deps.
|
||||
func NewQueueWorker(deps *Deps) *QueueWorker {
|
||||
return &QueueWorker{
|
||||
deps: deps,
|
||||
interval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Run polls the queue until stopCh is closed.
|
||||
func (w *QueueWorker) Run(stopCh <-chan struct{}) {
|
||||
log.Println("[queue] worker started")
|
||||
ticker := time.NewTicker(w.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Drain immediately on start.
|
||||
w.drainQueue()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
log.Println("[queue] worker stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.drainQueue()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// drainQueue fetches all due entries and delivers them.
|
||||
func (w *QueueWorker) drainQueue() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
entries, err := w.deps.DB.PeekQueue(ctx, 100)
|
||||
if err != nil {
|
||||
log.Printf("[queue] peek error: %v", err)
|
||||
return
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[queue] %d entries ready for delivery", len(entries))
|
||||
|
||||
// Derive the global queue decryption key once.
|
||||
queueKey, err := w.deps.Crypt.DeriveKeyGlobal("queue")
|
||||
if err != nil {
|
||||
log.Printf("[queue] derive key: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
// Mark in-progress to prevent parallel workers picking the same entry.
|
||||
nextAttempt := time.Now().Add(5 * time.Minute) // safety: reset on success
|
||||
if err := w.deps.DB.SetQueueStatus(ctx, entry.ID, "sending", "", &nextAttempt); err != nil {
|
||||
log.Printf("[queue] mark sending %d: %v", entry.ID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
raw, err := crypto.Decrypt(queueKey, entry.RawEnc)
|
||||
if err != nil {
|
||||
log.Printf("[queue] decrypt %d: %v", entry.ID, err)
|
||||
w.markFailed(ctx, entry.ID, entry.FromAddr, entry.ToAddr, "decrypt error: "+err.Error(), true)
|
||||
continue
|
||||
}
|
||||
|
||||
ehlo := w.deps.Cfg.SMTPHostname
|
||||
if ehlo == "" {
|
||||
ehlo = w.deps.Cfg.Hostname
|
||||
}
|
||||
|
||||
result := delivery.Deliver(ctx, ehlo, entry.FromAddr, entry.ToAddr, raw)
|
||||
|
||||
if result.SMTPCode == 250 {
|
||||
// Success.
|
||||
if err := w.deps.DB.SetQueueStatus(ctx, entry.ID, "delivered", "delivered", nil); err != nil {
|
||||
log.Printf("[queue] mark delivered %d: %v", entry.ID, err)
|
||||
}
|
||||
w.deps.DB.LogDelivery(ctx, entry.ID, //nolint:errcheck
|
||||
entry.FromAddr, entry.ToAddr,
|
||||
"delivered", result.SMTPCode, result.Message, result.MXHost)
|
||||
log.Printf("[queue] delivered %d: %s → %s via %s", entry.ID, entry.FromAddr, entry.ToAddr, result.MXHost)
|
||||
continue
|
||||
}
|
||||
|
||||
// Failure.
|
||||
w.markFailed(ctx, entry.ID, entry.FromAddr, entry.ToAddr, result.Message, result.Perm)
|
||||
}
|
||||
}
|
||||
|
||||
// markFailed updates queue status with exponential back-off or marks permanent failure.
|
||||
func (w *QueueWorker) markFailed(ctx context.Context, queueID int64, from, to, errMsg string, perm bool) {
|
||||
status := "failed"
|
||||
if perm {
|
||||
status = "bounced"
|
||||
}
|
||||
|
||||
var nextAttempt *time.Time
|
||||
if !perm {
|
||||
// Exponential back-off: 5m, 15m, 1h, 4h, 8h (then give up per expires_at).
|
||||
backoff := w.nextBackoff(ctx, queueID)
|
||||
t := time.Now().Add(backoff)
|
||||
nextAttempt = &t
|
||||
}
|
||||
|
||||
if err := w.deps.DB.SetQueueStatus(ctx, queueID, status, errMsg, nextAttempt); err != nil {
|
||||
log.Printf("[queue] update status %d: %v", queueID, err)
|
||||
}
|
||||
w.deps.DB.LogDelivery(ctx, queueID, from, to, status, 0, errMsg, "") //nolint:errcheck
|
||||
log.Printf("[queue] %s %d → %s: %s", status, queueID, to, errMsg)
|
||||
}
|
||||
|
||||
// nextBackoff returns the back-off duration based on attempt count using
|
||||
// the configured schedule (or a default).
|
||||
func (w *QueueWorker) nextBackoff(ctx context.Context, queueID int64) time.Duration {
|
||||
var attempts int
|
||||
_ = w.deps.DB.SQL().QueryRowContext(ctx,
|
||||
"SELECT attempts FROM queue WHERE id=?", queueID).Scan(&attempts)
|
||||
|
||||
schedule := w.deps.Cfg.QueueRetryMins
|
||||
if len(schedule) == 0 {
|
||||
// Default: 5m, 15m, 60m, 240m, 480m
|
||||
schedule = []int{5, 15, 60, 240, 480}
|
||||
}
|
||||
|
||||
idx := attempts
|
||||
if idx >= len(schedule) {
|
||||
idx = len(schedule) - 1
|
||||
}
|
||||
return time.Duration(schedule[idx]) * time.Minute
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
// Package smtp provides SMTP inbound (port 25) and submission (587/465) servers
|
||||
// using github.com/emersion/go-smtp as the protocol layer.
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
gosmtp "github.com/emersion/go-smtp"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/auth"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/config"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/spam"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/storage"
|
||||
)
|
||||
|
||||
// Deps groups all dependencies shared by SMTP servers.
|
||||
type Deps struct {
|
||||
DB *db.DB
|
||||
Crypt *crypto.Crypto
|
||||
Store *storage.Store
|
||||
Scorer *spam.Scorer
|
||||
Brute *auth.BruteGuard
|
||||
Cfg *config.Config
|
||||
}
|
||||
|
||||
// NewInboundServer creates the SMTP server for port 25 (inbound MTA).
|
||||
// Accepts mail for local domains from any sender. STARTTLS offered but not required.
|
||||
func NewInboundServer(d *Deps, tlsCfg *tls.Config) *gosmtp.Server {
|
||||
be := &InboundBackend{deps: d}
|
||||
|
||||
s := gosmtp.NewServer(be)
|
||||
s.Addr = fmt.Sprintf("%s:%d", d.Cfg.SMTPIface, d.Cfg.SMTPPort)
|
||||
s.Domain = d.Cfg.SMTPHostname
|
||||
s.MaxMessageBytes = d.Cfg.MaxMessageSize
|
||||
s.MaxRecipients = d.Cfg.MaxRcptPer
|
||||
s.WriteTimeout = 5 * time.Minute
|
||||
s.ReadTimeout = 5 * time.Minute
|
||||
s.AllowInsecureAuth = false // AUTH not offered on port 25
|
||||
|
||||
if tlsCfg != nil {
|
||||
// Setting TLSConfig enables STARTTLS automatically in go-smtp v0.24+.
|
||||
s.TLSConfig = tlsCfg
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// NewSubmissionServer creates the SMTP server for port 587 (STARTTLS submission).
|
||||
// Auth required. STARTTLS mandatory before AUTH.
|
||||
func NewSubmissionServer(d *Deps, tlsCfg *tls.Config) *gosmtp.Server {
|
||||
be := &SubmissionBackend{deps: d}
|
||||
|
||||
s := gosmtp.NewServer(be)
|
||||
s.Addr = fmt.Sprintf("%s:%d", d.Cfg.SubmitIface, d.Cfg.SubmitPort)
|
||||
s.Domain = d.Cfg.SMTPHostname
|
||||
s.MaxMessageBytes = d.Cfg.MaxMessageSize
|
||||
s.MaxRecipients = d.Cfg.MaxRcptPer
|
||||
s.WriteTimeout = 5 * time.Minute
|
||||
s.ReadTimeout = 5 * time.Minute
|
||||
s.AllowInsecureAuth = false // auth only after STARTTLS
|
||||
|
||||
if tlsCfg != nil {
|
||||
s.TLSConfig = tlsCfg
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// NewSMTPSServer creates the SMTP server for port 465 (implicit TLS submission).
|
||||
func NewSMTPSServer(d *Deps, tlsCfg *tls.Config) *gosmtp.Server {
|
||||
be := &SubmissionBackend{deps: d}
|
||||
|
||||
s := gosmtp.NewServer(be)
|
||||
s.Addr = fmt.Sprintf("%s:%d", d.Cfg.SMTPIface, d.Cfg.SMTPSPort)
|
||||
s.Domain = d.Cfg.SMTPHostname
|
||||
s.MaxMessageBytes = d.Cfg.MaxMessageSize
|
||||
s.MaxRecipients = d.Cfg.MaxRcptPer
|
||||
s.WriteTimeout = 5 * time.Minute
|
||||
s.ReadTimeout = 5 * time.Minute
|
||||
s.AllowInsecureAuth = false
|
||||
|
||||
if tlsCfg != nil {
|
||||
s.TLSConfig = tlsCfg
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// ListenAndServe starts a server and logs errors.
|
||||
func ListenAndServe(s *gosmtp.Server, name string) error {
|
||||
log.Printf("[%s] listening on %s", name, s.Addr)
|
||||
return s.ListenAndServe()
|
||||
}
|
||||
|
||||
// ListenAndServeTLS starts a server with implicit TLS (port 465).
|
||||
func ListenAndServeTLS(s *gosmtp.Server, name string) error {
|
||||
log.Printf("[%s] listening on %s (TLS)", name, s.Addr)
|
||||
return s.ListenAndServeTLS()
|
||||
}
|
||||
@@ -0,0 +1,281 @@
|
||||
package smtp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gosmtp "github.com/emersion/go-smtp"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/dkim"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/storage"
|
||||
)
|
||||
|
||||
// SubmissionBackend implements gosmtp.Backend for ports 587/465.
|
||||
// Requires authenticated users. Signs outbound mail with DKIM. Queues for delivery.
|
||||
type SubmissionBackend struct {
|
||||
deps *Deps
|
||||
}
|
||||
|
||||
func (b *SubmissionBackend) NewSession(c *gosmtp.Conn) (gosmtp.Session, error) {
|
||||
clientIP, _, _ := net.SplitHostPort(c.Conn().RemoteAddr().String())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
var banned bool
|
||||
if b.deps.Brute != nil {
|
||||
banned, _ = b.deps.Brute.IsBanned(ctx, clientIP)
|
||||
}
|
||||
if banned {
|
||||
return nil, fmt.Errorf("connection refused")
|
||||
}
|
||||
|
||||
return &SubmissionSession{
|
||||
deps: b.deps,
|
||||
clientIP: clientIP,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SubmissionSession handles one authenticated submission connection.
|
||||
type SubmissionSession struct {
|
||||
deps *Deps
|
||||
clientIP string
|
||||
user *models.User // set after AUTH
|
||||
from string
|
||||
rcpts []string
|
||||
}
|
||||
|
||||
func (s *SubmissionSession) AuthPlain(username, password string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
user, err := s.deps.DB.GetUserByEmail(ctx, username)
|
||||
if err != nil {
|
||||
log.Printf("[smtp/submission] auth lookup error %s: %v", username, err)
|
||||
return &gosmtp.SMTPError{Code: 535, Message: "authentication failed"}
|
||||
}
|
||||
if user == nil || !user.Enabled {
|
||||
s.logAttempt(ctx, username, false)
|
||||
return &gosmtp.SMTPError{Code: 535, Message: "authentication failed"}
|
||||
}
|
||||
|
||||
if err := crypto.CheckPassword(user.PasswordHash, password); err != nil {
|
||||
s.logAttempt(ctx, username, false)
|
||||
log.Printf("[smtp/submission] auth failed for %s from %s", username, s.clientIP)
|
||||
return &gosmtp.SMTPError{Code: 535, Message: "authentication failed"}
|
||||
}
|
||||
|
||||
s.user = user
|
||||
s.deps.DB.UpdateLastLogin(ctx, user.ID)
|
||||
log.Printf("[smtp/submission] auth OK for %s from %s", username, s.clientIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SubmissionSession) Mail(from string, opts *gosmtp.MailOptions) error {
|
||||
if s.user == nil {
|
||||
return &gosmtp.SMTPError{Code: 530, Message: "authentication required"}
|
||||
}
|
||||
|
||||
addr, err := mail.ParseAddress(from)
|
||||
if err != nil {
|
||||
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 7}, Message: "invalid sender"}
|
||||
}
|
||||
|
||||
// Sender must be user's own address or an alias they own.
|
||||
fromEmail := strings.ToLower(addr.Address)
|
||||
if !strings.EqualFold(fromEmail, s.user.Email) {
|
||||
// Check aliases.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resolved, err := s.deps.DB.ResolveEmail(ctx, fromEmail)
|
||||
if err != nil || resolved == nil || resolved.ID != s.user.ID {
|
||||
return &gosmtp.SMTPError{Code: 553, EnhancedCode: gosmtp.EnhancedCode{5, 1, 8}, Message: "sender not owned by authenticated user"}
|
||||
}
|
||||
}
|
||||
|
||||
s.from = addr.Address
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SubmissionSession) Rcpt(to string, opts *gosmtp.RcptOptions) error {
|
||||
if s.user == nil {
|
||||
return &gosmtp.SMTPError{Code: 530, Message: "authentication required"}
|
||||
}
|
||||
|
||||
addr, err := mail.ParseAddress(to)
|
||||
if err != nil {
|
||||
return &gosmtp.SMTPError{Code: 501, EnhancedCode: gosmtp.EnhancedCode{5, 1, 3}, Message: "invalid recipient"}
|
||||
}
|
||||
|
||||
s.rcpts = append(s.rcpts, addr.Address)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SubmissionSession) Data(r io.Reader) error {
|
||||
if s.user == nil {
|
||||
return &gosmtp.SMTPError{Code: 530, Message: "authentication required"}
|
||||
}
|
||||
if len(s.rcpts) == 0 {
|
||||
return &gosmtp.SMTPError{Code: 503, Message: "no recipients"}
|
||||
}
|
||||
|
||||
raw, err := io.ReadAll(io.LimitReader(r, s.deps.Cfg.MaxMessageSize+1))
|
||||
if err != nil {
|
||||
return fmt.Errorf("read submission data: %w", err)
|
||||
}
|
||||
if int64(len(raw)) > s.deps.Cfg.MaxMessageSize {
|
||||
return &gosmtp.SMTPError{Code: 552, EnhancedCode: gosmtp.EnhancedCode{5, 3, 4}, Message: "message too large"}
|
||||
}
|
||||
|
||||
// Parse for basic header validation.
|
||||
_, err = mail.ReadMessage(bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return &gosmtp.SMTPError{Code: 550, EnhancedCode: gosmtp.EnhancedCode{5, 6, 0}, Message: "malformed message"}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// DKIM-sign the message if the sender's domain has keys configured.
|
||||
senderDomain := domainOf(s.from)
|
||||
raw = s.signDKIM(ctx, raw, senderDomain)
|
||||
|
||||
msgID := extractMsgID(raw)
|
||||
|
||||
// Queue each recipient for delivery.
|
||||
// For local recipients we could deliver directly, but queuing is simpler and
|
||||
// provides a consistent audit trail.
|
||||
dom, err := s.deps.DB.GetDomain(ctx, senderDomain)
|
||||
var domainID int64
|
||||
if err == nil && dom != nil {
|
||||
domainID = dom.ID
|
||||
}
|
||||
|
||||
// Encrypt raw for queue storage using a global (non-user) key.
|
||||
queueKey, err := s.deps.Crypt.DeriveKeyGlobal("queue")
|
||||
if err != nil {
|
||||
return fmt.Errorf("queue key: %w", err)
|
||||
}
|
||||
rawEnc, err := crypto.Encrypt(queueKey, raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt for queue: %w", err)
|
||||
}
|
||||
|
||||
for _, rcpt := range s.rcpts {
|
||||
_, err := s.deps.DB.EnqueueMessage(ctx, domainID, s.from, rcpt, msgID, rawEnc, s.deps.Cfg.QueueMaxAgeHours)
|
||||
if err != nil {
|
||||
log.Printf("[smtp/submission] enqueue to %s: %v", rcpt, err)
|
||||
return &gosmtp.SMTPError{Code: 451, EnhancedCode: gosmtp.EnhancedCode{4, 3, 0}, Message: "queue error"}
|
||||
}
|
||||
log.Printf("[smtp/submission] queued %s → %s", s.from, rcpt)
|
||||
}
|
||||
|
||||
// Also save a copy in sender's Sent folder.
|
||||
s.saveSentCopy(ctx, raw)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SubmissionSession) Reset() {
|
||||
s.from = ""
|
||||
s.rcpts = s.rcpts[:0]
|
||||
}
|
||||
|
||||
func (s *SubmissionSession) Logout() error { return nil }
|
||||
|
||||
// signDKIM signs the message with the sender domain's DKIM key if available.
|
||||
// Returns the original raw on any error (DKIM is best-effort).
|
||||
func (s *SubmissionSession) signDKIM(ctx context.Context, raw []byte, senderDomain string) []byte {
|
||||
dom, err := s.deps.DB.GetDomain(ctx, senderDomain)
|
||||
if err != nil || dom == nil || dom.DKIMPrivateEnc == nil {
|
||||
return raw // no key configured
|
||||
}
|
||||
|
||||
privPEM, err := s.deps.Crypt.DecryptGlobal("dkim", dom.DKIMPrivateEnc)
|
||||
if err != nil {
|
||||
log.Printf("[smtp/submission] dkim key decrypt for %s: %v", senderDomain, err)
|
||||
return raw
|
||||
}
|
||||
|
||||
signer, err := dkim.NewSigner(string(privPEM), senderDomain, dom.DKIMSelector)
|
||||
if err != nil {
|
||||
log.Printf("[smtp/submission] dkim signer for %s: %v", senderDomain, err)
|
||||
return raw
|
||||
}
|
||||
|
||||
header, err := signer.Sign(raw)
|
||||
if err != nil {
|
||||
log.Printf("[smtp/submission] dkim sign for %s: %v", senderDomain, err)
|
||||
return raw
|
||||
}
|
||||
|
||||
// Prepend DKIM-Signature header.
|
||||
return append([]byte(header+"\r\n"), raw...)
|
||||
}
|
||||
|
||||
// saveSentCopy stores a copy in the user's Sent mailbox (best-effort).
|
||||
func (s *SubmissionSession) saveSentCopy(ctx context.Context, raw []byte) {
|
||||
mbox, err := s.deps.DB.GetMailboxByType(ctx, s.user.ID, models.MailboxSent)
|
||||
if err != nil || mbox == nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := mail.ReadMessage(strings.NewReader(string(raw)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
subject := msg.Header.Get("Subject")
|
||||
msgDate, _ := mail.ParseDate(msg.Header.Get("Date"))
|
||||
if msgDate.IsZero() {
|
||||
msgDate = time.Now().UTC()
|
||||
}
|
||||
|
||||
incoming := &storage.IncomingMessage{
|
||||
Raw: raw,
|
||||
FromEmail: s.from,
|
||||
ToList: strings.Join(s.rcpts, ", "),
|
||||
Subject: subject,
|
||||
Date: msgDate,
|
||||
MessageID: extractMsgID(raw),
|
||||
SpamScore: 0,
|
||||
}
|
||||
|
||||
if _, err := s.deps.Store.SaveIncoming(ctx, s.user.ID, mbox.ID, incoming); err != nil {
|
||||
log.Printf("[smtp/submission] save sent copy: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- helpers ----
|
||||
|
||||
func (s *SubmissionSession) logAttempt(ctx context.Context, email string, success bool) {
|
||||
if s.deps.Brute != nil {
|
||||
s.deps.Brute.RecordAttempt(ctx, s.clientIP, email, success) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
func domainOf(email string) string {
|
||||
at := strings.LastIndex(email, "@")
|
||||
if at < 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.ToLower(email[at+1:])
|
||||
}
|
||||
|
||||
func extractMsgID(raw []byte) string {
|
||||
msg, err := mail.ReadMessage(bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return msg.Header.Get("Message-ID")
|
||||
}
|
||||
@@ -0,0 +1,314 @@
|
||||
// Package spam scores inbound messages using static heuristics plus
|
||||
// optional Bayesian token analysis (per RFC 5965 conventions).
|
||||
// Score >= threshold → deliver to Spam folder.
|
||||
package spam
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/spf"
|
||||
)
|
||||
|
||||
// Scorer evaluates spam likelihood.
|
||||
type Scorer struct {
|
||||
db *db.DB
|
||||
threshold int
|
||||
dnsbl []string
|
||||
checkSPF bool
|
||||
checkDKIM bool
|
||||
}
|
||||
|
||||
// Result holds the total spam score and the component breakdown.
|
||||
type Result struct {
|
||||
Total int
|
||||
Reasons []string
|
||||
IsSpam bool
|
||||
}
|
||||
|
||||
// Params groups message features for scoring.
|
||||
type Params struct {
|
||||
ClientIP net.IP
|
||||
SenderDomain string
|
||||
SPFResult spf.Result
|
||||
DKIMValid bool
|
||||
DKIMPresent bool
|
||||
DMARCFail bool
|
||||
Subject string
|
||||
FromHeader string
|
||||
HasHTMLOnly bool // true if no text/plain part
|
||||
RecipCount int
|
||||
HasDateHeader bool
|
||||
HasMsgIDHeader bool
|
||||
BodyText string // first 1000 bytes of plain text for token analysis
|
||||
}
|
||||
|
||||
// NewScorer creates a scorer from config values.
|
||||
func NewScorer(database *db.DB, threshold int, dnsbl []string, checkSPF, checkDKIM bool) *Scorer {
|
||||
return &Scorer{
|
||||
db: database,
|
||||
threshold: threshold,
|
||||
dnsbl: dnsbl,
|
||||
checkSPF: checkSPF,
|
||||
checkDKIM: checkDKIM,
|
||||
}
|
||||
}
|
||||
|
||||
// Score evaluates the message and returns a Result.
|
||||
func (s *Scorer) Score(ctx context.Context, userID int64, p *Params) *Result {
|
||||
r := &Result{}
|
||||
add := func(pts int, reason string) {
|
||||
r.Total += pts
|
||||
r.Reasons = append(r.Reasons, fmt.Sprintf("+%d: %s", pts, reason))
|
||||
}
|
||||
|
||||
// DNSBL check (async-ish: each lookup gets its own goroutine with timeout)
|
||||
if p.ClientIP != nil {
|
||||
hits := s.dnsblCheck(ctx, p.ClientIP)
|
||||
for _, bl := range hits {
|
||||
add(5, "DNSBL hit: "+bl)
|
||||
}
|
||||
}
|
||||
|
||||
// SPF
|
||||
if s.checkSPF {
|
||||
switch p.SPFResult {
|
||||
case spf.ResultFail:
|
||||
add(4, "SPF fail")
|
||||
case spf.ResultSoftFail:
|
||||
add(2, "SPF softfail")
|
||||
case spf.ResultNone:
|
||||
add(1, "SPF none (no record)")
|
||||
}
|
||||
}
|
||||
|
||||
// DKIM
|
||||
if s.checkDKIM {
|
||||
if !p.DKIMPresent {
|
||||
add(2, "DKIM absent")
|
||||
} else if !p.DKIMValid {
|
||||
add(3, "DKIM invalid signature")
|
||||
}
|
||||
}
|
||||
|
||||
// DMARC
|
||||
if p.DMARCFail {
|
||||
add(5, "DMARC fail")
|
||||
}
|
||||
|
||||
// Missing required headers.
|
||||
if !p.HasDateHeader {
|
||||
add(1, "missing Date header")
|
||||
}
|
||||
if !p.HasMsgIDHeader {
|
||||
add(1, "missing Message-ID header")
|
||||
}
|
||||
|
||||
// HTML-only (no text/plain).
|
||||
if p.HasHTMLOnly {
|
||||
add(1, "HTML-only body")
|
||||
}
|
||||
|
||||
// All-caps subject.
|
||||
if p.Subject != "" && isAllCaps(p.Subject) {
|
||||
add(1, "all-caps subject")
|
||||
}
|
||||
|
||||
// Excessive recipients.
|
||||
if p.RecipCount > 20 {
|
||||
add(2, fmt.Sprintf("excessive recipients (%d)", p.RecipCount))
|
||||
}
|
||||
|
||||
// Bayesian (per-user trained model).
|
||||
if userID > 0 && p.BodyText != "" {
|
||||
bayesScore, err := s.bayesScore(ctx, userID, p.BodyText)
|
||||
if err == nil && bayesScore >= 0.8 {
|
||||
pts := int((bayesScore - 0.7) * 20) // 0.8→2 pts, 0.9→4 pts, 1.0→6 pts
|
||||
add(pts, fmt.Sprintf("Bayesian score %.2f", bayesScore))
|
||||
}
|
||||
}
|
||||
|
||||
r.IsSpam = r.Total >= s.threshold
|
||||
return r
|
||||
}
|
||||
|
||||
// TrainSpam adds body tokens to the user's spam corpus.
|
||||
func (s *Scorer) TrainSpam(ctx context.Context, userID int64, body string) error {
|
||||
return s.trainTokens(ctx, userID, body, true)
|
||||
}
|
||||
|
||||
// TrainHam adds body tokens to the user's ham corpus.
|
||||
func (s *Scorer) TrainHam(ctx context.Context, userID int64, body string) error {
|
||||
return s.trainTokens(ctx, userID, body, false)
|
||||
}
|
||||
|
||||
// ---- DNSBL ----
|
||||
|
||||
func (s *Scorer) dnsblCheck(ctx context.Context, ip net.IP) []string {
|
||||
ipv4 := ip.To4()
|
||||
if ipv4 == nil {
|
||||
return nil // DNSBL queries are IPv4-only for now
|
||||
}
|
||||
|
||||
// Reverse the IP octets: 1.2.3.4 → 4.3.2.1
|
||||
reversed := fmt.Sprintf("%d.%d.%d.%d", ipv4[3], ipv4[2], ipv4[1], ipv4[0])
|
||||
|
||||
type result struct{ bl string }
|
||||
hits := make(chan result, len(s.dnsbl))
|
||||
|
||||
timeout, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
for _, bl := range s.dnsbl {
|
||||
bl := bl
|
||||
go func() {
|
||||
query := reversed + "." + bl
|
||||
addrs, err := net.DefaultResolver.LookupHost(timeout, query)
|
||||
if err == nil && len(addrs) > 0 {
|
||||
hits <- result{bl}
|
||||
} else {
|
||||
hits <- result{}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var matched []string
|
||||
for range s.dnsbl {
|
||||
if h := <-hits; h.bl != "" {
|
||||
matched = append(matched, h.bl)
|
||||
}
|
||||
}
|
||||
return matched
|
||||
}
|
||||
|
||||
// ---- Bayesian ----
|
||||
|
||||
func tokenize(body string) []string {
|
||||
body = strings.ToLower(body)
|
||||
var tokens []string
|
||||
seen := make(map[string]struct{})
|
||||
words := strings.FieldsFunc(body, func(r rune) bool {
|
||||
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
|
||||
})
|
||||
for _, w := range words {
|
||||
if len(w) < 3 || len(w) > 30 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[w]; ok {
|
||||
continue
|
||||
}
|
||||
seen[w] = struct{}{}
|
||||
tokens = append(tokens, w)
|
||||
if len(tokens) >= 200 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// bayesScore returns the probability [0,1] that the message is spam.
|
||||
func (s *Scorer) bayesScore(ctx context.Context, userID int64, body string) (float64, error) {
|
||||
tokens := tokenize(body)
|
||||
if len(tokens) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Fetch total spam/ham message counts for this user (proxy: count rows with nonzero counts).
|
||||
var totalSpam, totalHam int64
|
||||
err := s.db.SQL().QueryRowContext(ctx, `
|
||||
SELECT COALESCE(SUM(spam_count),0), COALESCE(SUM(ham_count),0)
|
||||
FROM spam_tokens WHERE user_id=?`, userID).Scan(&totalSpam, &totalHam)
|
||||
if err != nil || (totalSpam+totalHam) < 50 {
|
||||
// Not enough training data.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Naive Bayes: P(spam|words) ∝ Π P(word|spam) / Π P(word|ham)
|
||||
// Use log-probabilities to avoid underflow.
|
||||
var logP float64
|
||||
|
||||
placeholders := make([]string, len(tokens))
|
||||
args := make([]interface{}, len(tokens)+1)
|
||||
args[0] = userID
|
||||
for i, tok := range tokens {
|
||||
placeholders[i] = "?"
|
||||
args[i+1] = tok
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT token, spam_count, ham_count FROM spam_tokens
|
||||
WHERE user_id=? AND token IN (%s)`, strings.Join(placeholders, ","))
|
||||
|
||||
rows, err := s.db.SQL().QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
tokenData := make(map[string][2]int64)
|
||||
for rows.Next() {
|
||||
var tok string
|
||||
var sc, hc int64
|
||||
if err := rows.Scan(&tok, &sc, &hc); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tokenData[tok] = [2]int64{sc, hc}
|
||||
}
|
||||
|
||||
for _, tok := range tokens {
|
||||
counts := tokenData[tok]
|
||||
sc, hc := counts[0], counts[1]
|
||||
|
||||
// Laplace smoothing.
|
||||
pSpam := float64(sc+1) / float64(totalSpam+2)
|
||||
pHam := float64(hc+1) / float64(totalHam+2)
|
||||
|
||||
logP += math.Log(pSpam) - math.Log(pHam)
|
||||
}
|
||||
|
||||
// Convert log-odds back to probability.
|
||||
prob := 1.0 / (1.0 + math.Exp(-logP))
|
||||
return prob, nil
|
||||
}
|
||||
|
||||
func (s *Scorer) trainTokens(ctx context.Context, userID int64, body string, isSpam bool) error {
|
||||
tokens := tokenize(body)
|
||||
for _, tok := range tokens {
|
||||
var col string
|
||||
if isSpam {
|
||||
col = "spam_count"
|
||||
} else {
|
||||
col = "ham_count"
|
||||
}
|
||||
|
||||
_, err := s.db.SQL().ExecContext(ctx, fmt.Sprintf(`
|
||||
INSERT INTO spam_tokens (user_id, token, %s)
|
||||
VALUES (?, ?, 1)
|
||||
ON CONFLICT(user_id, token) DO UPDATE SET %s=%s+1`, col, col, col),
|
||||
userID, tok)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return fmt.Errorf("train token %q: %w", tok, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAllCaps(s string) bool {
|
||||
hasLetter := false
|
||||
for _, r := range s {
|
||||
if unicode.IsLetter(r) {
|
||||
hasLetter = true
|
||||
if unicode.IsLower(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return hasLetter
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
// Package spf implements basic SPF (RFC 7208) DNS lookup and policy evaluation.
|
||||
// Only stdlib net is used for DNS. Lookup limit: 10 DNS mechanisms per spec.
|
||||
package spf
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Result values per RFC 7208 §2.6.
|
||||
type Result int
|
||||
|
||||
const (
|
||||
ResultNone Result = iota // no SPF record
|
||||
ResultNeutral // ?all
|
||||
ResultPass // sender is permitted
|
||||
ResultFail // sender is not permitted (hard fail)
|
||||
ResultSoftFail // ~all (likely spam)
|
||||
ResultTempError // transient DNS error
|
||||
ResultPermError // permanent SPF parse error
|
||||
)
|
||||
|
||||
func (r Result) String() string {
|
||||
return [...]string{"none", "neutral", "pass", "fail", "softfail", "temperror", "permerror"}[r]
|
||||
}
|
||||
|
||||
// Check performs SPF evaluation for the given sender IP and envelope-from domain.
|
||||
// Returns the SPF result and an explanation string.
|
||||
func Check(clientIP net.IP, senderDomain string) (Result, string) {
|
||||
if senderDomain == "" {
|
||||
return ResultNone, "empty sender domain"
|
||||
}
|
||||
|
||||
record, err := fetchSPF(senderDomain)
|
||||
if err != nil {
|
||||
return ResultTempError, fmt.Sprintf("SPF DNS lookup %s: %v", senderDomain, err)
|
||||
}
|
||||
if record == "" {
|
||||
return ResultNone, "no SPF record for " + senderDomain
|
||||
}
|
||||
|
||||
result, msg := evaluate(clientIP, senderDomain, record, 0)
|
||||
return result, msg
|
||||
}
|
||||
|
||||
// fetchSPF queries TXT records for the domain and returns the first SPF record.
|
||||
func fetchSPF(domain string) (string, error) {
|
||||
txts, err := net.LookupTXT(domain)
|
||||
if err != nil {
|
||||
// Treat NXDOMAIN as "no record" not an error.
|
||||
if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.IsNotFound {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
for _, txt := range txts {
|
||||
txt = strings.TrimSpace(txt)
|
||||
if strings.HasPrefix(txt, "v=spf1") {
|
||||
return txt, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// evaluate parses and applies SPF mechanisms. dnsLookups tracks the lookup count.
|
||||
func evaluate(ip net.IP, domain, record string, dnsLookups int) (Result, string) {
|
||||
parts := strings.Fields(record)
|
||||
if len(parts) == 0 || !strings.EqualFold(parts[0], "v=spf1") {
|
||||
return ResultPermError, "invalid SPF record: " + record
|
||||
}
|
||||
|
||||
for _, term := range parts[1:] {
|
||||
if dnsLookups > 10 {
|
||||
return ResultPermError, "exceeded 10 DNS lookups"
|
||||
}
|
||||
|
||||
// Qualifier prefix: +pass -fail ~softfail ?neutral
|
||||
qualifier := ResultPass
|
||||
switch term[0] {
|
||||
case '+':
|
||||
qualifier = ResultPass
|
||||
term = term[1:]
|
||||
case '-':
|
||||
qualifier = ResultFail
|
||||
term = term[1:]
|
||||
case '~':
|
||||
qualifier = ResultSoftFail
|
||||
term = term[1:]
|
||||
case '?':
|
||||
qualifier = ResultNeutral
|
||||
term = term[1:]
|
||||
}
|
||||
|
||||
lower := strings.ToLower(term)
|
||||
|
||||
switch {
|
||||
case lower == "all":
|
||||
return qualifier, "matched 'all'"
|
||||
|
||||
case strings.HasPrefix(lower, "ip4:"):
|
||||
cidr := term[4:]
|
||||
if !strings.Contains(cidr, "/") {
|
||||
cidr += "/32"
|
||||
}
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if ip.To4() != nil && network.Contains(ip) {
|
||||
return qualifier, fmt.Sprintf("matched ip4:%s", term[4:])
|
||||
}
|
||||
|
||||
case strings.HasPrefix(lower, "ip6:"):
|
||||
cidr := term[4:]
|
||||
if !strings.Contains(cidr, "/") {
|
||||
cidr += "/128"
|
||||
}
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if ip.To4() == nil && network.Contains(ip) {
|
||||
return qualifier, fmt.Sprintf("matched ip6:%s", term[4:])
|
||||
}
|
||||
|
||||
case lower == "a" || strings.HasPrefix(lower, "a:") || strings.HasPrefix(lower, "a/"):
|
||||
dnsLookups++
|
||||
checkDomain := domain
|
||||
if strings.HasPrefix(lower, "a:") {
|
||||
checkDomain = term[2:]
|
||||
}
|
||||
addrs, err := net.LookupHost(checkDomain)
|
||||
if err == nil {
|
||||
for _, addr := range addrs {
|
||||
if net.ParseIP(addr).Equal(ip) {
|
||||
return qualifier, fmt.Sprintf("matched a:%s", checkDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case lower == "mx" || strings.HasPrefix(lower, "mx:") || strings.HasPrefix(lower, "mx/"):
|
||||
dnsLookups++
|
||||
checkDomain := domain
|
||||
if strings.HasPrefix(lower, "mx:") {
|
||||
checkDomain = term[3:]
|
||||
}
|
||||
mxs, err := net.LookupMX(checkDomain)
|
||||
if err == nil {
|
||||
for _, mx := range mxs {
|
||||
dnsLookups++
|
||||
addrs, err := net.LookupHost(mx.Host)
|
||||
if err == nil {
|
||||
for _, addr := range addrs {
|
||||
if net.ParseIP(addr).Equal(ip) {
|
||||
return qualifier, fmt.Sprintf("matched mx:%s", checkDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case strings.HasPrefix(lower, "include:"):
|
||||
includeDomain := term[8:]
|
||||
dnsLookups++
|
||||
includeRecord, err := fetchSPF(includeDomain)
|
||||
if err != nil {
|
||||
return ResultTempError, fmt.Sprintf("include %s DNS error: %v", includeDomain, err)
|
||||
}
|
||||
if includeRecord == "" {
|
||||
return ResultPermError, "include domain has no SPF: " + includeDomain
|
||||
}
|
||||
subResult, subMsg := evaluate(ip, includeDomain, includeRecord, dnsLookups)
|
||||
if subResult == ResultPass {
|
||||
return qualifier, fmt.Sprintf("include:%s → %s", includeDomain, subMsg)
|
||||
}
|
||||
|
||||
case strings.HasPrefix(lower, "redirect="):
|
||||
redirectDomain := term[9:]
|
||||
dnsLookups++
|
||||
redirectRecord, err := fetchSPF(redirectDomain)
|
||||
if err != nil {
|
||||
return ResultTempError, fmt.Sprintf("redirect %s DNS error: %v", redirectDomain, err)
|
||||
}
|
||||
if redirectRecord == "" {
|
||||
return ResultNone, "redirect domain has no SPF: " + redirectDomain
|
||||
}
|
||||
return evaluate(ip, redirectDomain, redirectRecord, dnsLookups)
|
||||
|
||||
case strings.HasPrefix(lower, "exp="):
|
||||
// Explanation modifier — ignore.
|
||||
|
||||
default:
|
||||
// Unknown mechanism — ignore per spec.
|
||||
}
|
||||
}
|
||||
|
||||
// No mechanism matched.
|
||||
return ResultNeutral, "no mechanism matched"
|
||||
}
|
||||
@@ -0,0 +1,358 @@
|
||||
// Package storage provides encrypted message persistence for both SQLite (db)
|
||||
// and filesystem (fs) backends. All body and raw data is encrypted with
|
||||
// AES-256-GCM before being written.
|
||||
package storage
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"mime/quotedprintable"
|
||||
"net/mail"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/crypto"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
|
||||
"ghb.freebede.com/nahakubuilder/mailgosend/internal/models"
|
||||
)
|
||||
|
||||
// Store is the encrypted message store.
|
||||
type Store struct {
|
||||
db *db.DB
|
||||
crypt *crypto.Crypto
|
||||
backend string // "db" | "fs"
|
||||
fsPath string
|
||||
}
|
||||
|
||||
// IncomingMessage is the raw inbound message plus parsed envelope fields.
|
||||
type IncomingMessage struct {
|
||||
Raw []byte // full RFC822
|
||||
FromEmail string
|
||||
FromName string
|
||||
ToList string
|
||||
CCList string
|
||||
BCCList string
|
||||
ReplyTo string
|
||||
Subject string
|
||||
Date time.Time
|
||||
MessageID string
|
||||
SpamScore int
|
||||
}
|
||||
|
||||
type parsedAttachment struct {
|
||||
Filename string
|
||||
ContentType string
|
||||
Data []byte
|
||||
ContentID string
|
||||
Inline bool
|
||||
MIMEPath string
|
||||
}
|
||||
|
||||
// New validates the backend and returns a ready Store.
|
||||
func New(database *db.DB, crypt *crypto.Crypto, backend, fsPath string) (*Store, error) {
|
||||
if backend != "db" && backend != "fs" {
|
||||
return nil, fmt.Errorf("storage: unknown backend %q (want \"db\" or \"fs\")", backend)
|
||||
}
|
||||
if backend == "fs" {
|
||||
if fsPath == "" {
|
||||
return nil, fmt.Errorf("storage: fsPath required for fs backend")
|
||||
}
|
||||
if err := os.MkdirAll(fsPath, 0700); err != nil {
|
||||
return nil, fmt.Errorf("storage: create fs dir: %w", err)
|
||||
}
|
||||
}
|
||||
return &Store{db: database, crypt: crypt, backend: backend, fsPath: fsPath}, nil
|
||||
}
|
||||
|
||||
// SaveIncoming encrypts and persists an incoming message plus its attachments.
|
||||
// Returns the new message row ID.
|
||||
func (s *Store) SaveIncoming(ctx context.Context, userID, mailboxID int64, msg *IncomingMessage) (int64, error) {
|
||||
uid, err := s.db.NextUID(ctx, mailboxID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("storage: next uid: %w", err)
|
||||
}
|
||||
|
||||
bodyText, bodyHTML, attachments := parseMIME(msg.Raw)
|
||||
|
||||
key, err := s.crypt.DeriveKey("messages", userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("storage: derive key: %w", err)
|
||||
}
|
||||
|
||||
bodyTextEnc, err := crypto.Encrypt(key, []byte(bodyText))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("storage: encrypt body text: %w", err)
|
||||
}
|
||||
|
||||
bodyHTMLEnc, err := crypto.Encrypt(key, []byte(bodyHTML))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("storage: encrypt body html: %w", err)
|
||||
}
|
||||
|
||||
rawEnc, err := crypto.Encrypt(key, msg.Raw)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("storage: encrypt raw: %w", err)
|
||||
}
|
||||
|
||||
insert := &db.MessageInsert{
|
||||
MailboxID: mailboxID,
|
||||
UID: uid,
|
||||
MessageID: msg.MessageID,
|
||||
Subject: msg.Subject,
|
||||
FromEmail: msg.FromEmail,
|
||||
FromName: msg.FromName,
|
||||
ToList: msg.ToList,
|
||||
CCList: msg.CCList,
|
||||
BCCList: msg.BCCList,
|
||||
ReplyTo: msg.ReplyTo,
|
||||
Date: msg.Date,
|
||||
BodyTextEnc: bodyTextEnc,
|
||||
BodyHTMLEnc: bodyHTMLEnc,
|
||||
RawEnc: rawEnc,
|
||||
SizeBytes: int64(len(msg.Raw)),
|
||||
HasAttachment: len(attachments) > 0,
|
||||
IsRead: false,
|
||||
IsStarred: false,
|
||||
IsDraft: false,
|
||||
Flags: "",
|
||||
SpamScore: msg.SpamScore,
|
||||
}
|
||||
|
||||
messageID, err := s.db.InsertMessage(ctx, insert)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("storage: insert message: %w", err)
|
||||
}
|
||||
|
||||
if err := s.db.UpdateUsedBytes(ctx, userID, int64(len(msg.Raw))); err != nil {
|
||||
return 0, fmt.Errorf("storage: update used bytes: %w", err)
|
||||
}
|
||||
|
||||
for _, att := range attachments {
|
||||
attEnc, err := crypto.Encrypt(key, att.Data)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("storage: encrypt attachment: %w", err)
|
||||
}
|
||||
|
||||
var dataPath string
|
||||
var dataEnc []byte
|
||||
|
||||
if s.backend == "fs" {
|
||||
h := sha256.Sum256(att.Data)
|
||||
name := hex.EncodeToString(h[:]) + ".att"
|
||||
p := filepath.Join(s.fsPath, name)
|
||||
if err := os.WriteFile(p, attEnc, 0600); err != nil {
|
||||
return 0, fmt.Errorf("storage: write attachment file: %w", err)
|
||||
}
|
||||
dataPath = p
|
||||
} else {
|
||||
dataEnc = attEnc
|
||||
}
|
||||
|
||||
attInsert := &db.AttachmentInsert{
|
||||
MessageID: messageID,
|
||||
Filename: att.Filename,
|
||||
ContentType: att.ContentType,
|
||||
SizeBytes: int64(len(att.Data)),
|
||||
DataEnc: dataEnc,
|
||||
DataPath: dataPath,
|
||||
ContentID: att.ContentID,
|
||||
Inline: att.Inline,
|
||||
MIMEPath: att.MIMEPath,
|
||||
}
|
||||
if _, err := s.db.InsertAttachment(ctx, attInsert); err != nil {
|
||||
return 0, fmt.Errorf("storage: insert attachment: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return messageID, nil
|
||||
}
|
||||
|
||||
// GetRaw returns the decrypted RFC822 blob for a message.
|
||||
func (s *Store) GetRaw(ctx context.Context, userID, messageID int64) ([]byte, error) {
|
||||
rawEnc, err := s.db.GetMessageRaw(ctx, messageID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storage: get raw enc: %w", err)
|
||||
}
|
||||
if rawEnc == nil {
|
||||
return nil, fmt.Errorf("storage: message %d not found", messageID)
|
||||
}
|
||||
|
||||
key, err := s.crypt.DeriveKey("messages", userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storage: derive key: %w", err)
|
||||
}
|
||||
|
||||
plain, err := crypto.Decrypt(key, rawEnc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storage: decrypt raw: %w", err)
|
||||
}
|
||||
return plain, nil
|
||||
}
|
||||
|
||||
// parseMIME walks a raw RFC822 message and extracts body parts and attachments.
|
||||
func parseMIME(raw []byte) (bodyText, bodyHTML string, attachments []parsedAttachment) {
|
||||
m, err := mail.ReadMessage(bytes.NewReader(raw))
|
||||
if err != nil {
|
||||
return "", "", nil
|
||||
}
|
||||
ct := m.Header.Get("Content-Type")
|
||||
body, err := io.ReadAll(m.Body)
|
||||
if err != nil {
|
||||
return "", "", nil
|
||||
}
|
||||
walkPart(ct, m.Header.Get("Content-Transfer-Encoding"),
|
||||
m.Header.Get("Content-Disposition"), m.Header.Get("Content-ID"),
|
||||
body, "1", &bodyText, &bodyHTML, &attachments)
|
||||
return bodyText, bodyHTML, attachments
|
||||
}
|
||||
|
||||
// walkPart recursively processes one MIME part.
|
||||
func walkPart(
|
||||
ctHeader, cteHeader, cdHeader, cidHeader string,
|
||||
data []byte,
|
||||
mimePath string,
|
||||
bodyText, bodyHTML *string,
|
||||
attachments *[]parsedAttachment,
|
||||
) {
|
||||
mediaType, params, err := mime.ParseMediaType(ctHeader)
|
||||
if err != nil {
|
||||
// Treat unparseable as text/plain.
|
||||
mediaType = "text/plain"
|
||||
params = map[string]string{}
|
||||
}
|
||||
|
||||
// Decode transfer encoding first if this is a leaf part.
|
||||
if !strings.HasPrefix(mediaType, "multipart/") {
|
||||
data = decodeCTE(cteHeader, data)
|
||||
}
|
||||
|
||||
switch {
|
||||
case mediaType == "text/plain" && !isAttachment(cdHeader):
|
||||
*bodyText = string(data)
|
||||
|
||||
case mediaType == "text/html" && !isAttachment(cdHeader):
|
||||
*bodyHTML = string(data)
|
||||
|
||||
case strings.HasPrefix(mediaType, "multipart/"):
|
||||
boundary, ok := params["boundary"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mr := multipart.NewReader(bytes.NewReader(data), boundary)
|
||||
partIdx := 1
|
||||
for {
|
||||
part, err := mr.NextPart()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
partData, err := io.ReadAll(part)
|
||||
if err != nil {
|
||||
part.Close()
|
||||
break
|
||||
}
|
||||
part.Close()
|
||||
|
||||
subCT := part.Header.Get("Content-Type")
|
||||
if subCT == "" {
|
||||
subCT = "text/plain"
|
||||
}
|
||||
subCTE := part.Header.Get("Content-Transfer-Encoding")
|
||||
subCD := part.Header.Get("Content-Disposition")
|
||||
subCID := part.Header.Get("Content-ID")
|
||||
subPath := fmt.Sprintf("%s.%d", mimePath, partIdx)
|
||||
|
||||
walkPart(subCT, subCTE, subCD, subCID, partData, subPath,
|
||||
bodyText, bodyHTML, attachments)
|
||||
partIdx++
|
||||
}
|
||||
|
||||
default:
|
||||
// Treat as attachment.
|
||||
filename := filenameFrom(cdHeader, ctHeader)
|
||||
inline := strings.HasPrefix(strings.ToLower(cdHeader), "inline")
|
||||
cleanCID := strings.Trim(cidHeader, "<>")
|
||||
*attachments = append(*attachments, parsedAttachment{
|
||||
Filename: filename,
|
||||
ContentType: mediaType,
|
||||
Data: data,
|
||||
ContentID: cleanCID,
|
||||
Inline: inline,
|
||||
MIMEPath: mimePath,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// decodeCTE applies quoted-printable or base64 decoding based on the
|
||||
// Content-Transfer-Encoding header value.
|
||||
func decodeCTE(cte string, data []byte) []byte {
|
||||
switch strings.ToLower(strings.TrimSpace(cte)) {
|
||||
case "quoted-printable":
|
||||
decoded, err := io.ReadAll(quotedprintable.NewReader(bytes.NewReader(data)))
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
return decoded
|
||||
case "base64":
|
||||
// Strip whitespace — base64 in email is line-wrapped.
|
||||
clean := strings.Map(func(r rune) rune {
|
||||
if r == '\r' || r == '\n' || r == ' ' || r == '\t' {
|
||||
return -1
|
||||
}
|
||||
return r
|
||||
}, string(data))
|
||||
decoded, err := base64.StdEncoding.DecodeString(clean)
|
||||
if err != nil {
|
||||
// Try raw/URL encoding as fallback.
|
||||
decoded, err = base64.RawStdEncoding.DecodeString(clean)
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
}
|
||||
return decoded
|
||||
default:
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
// isAttachment returns true when the Content-Disposition header signals attachment.
|
||||
func isAttachment(cd string) bool {
|
||||
if cd == "" {
|
||||
return false
|
||||
}
|
||||
disp, _, _ := mime.ParseMediaType(cd)
|
||||
return strings.EqualFold(disp, "attachment")
|
||||
}
|
||||
|
||||
// filenameFrom extracts a filename from Content-Disposition, falling back to
|
||||
// the name= param of Content-Type.
|
||||
func filenameFrom(cd, ct string) string {
|
||||
if cd != "" {
|
||||
_, params, err := mime.ParseMediaType(cd)
|
||||
if err == nil {
|
||||
if name, ok := params["filename"]; ok && name != "" {
|
||||
return filepath.Base(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if ct != "" {
|
||||
_, params, err := mime.ParseMediaType(ct)
|
||||
if err == nil {
|
||||
if name, ok := params["name"]; ok && name != "" {
|
||||
return filepath.Base(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
return "attachment"
|
||||
}
|
||||
|
||||
// Ensure models import is used (Subject field type reference avoids import pruning).
|
||||
var _ = models.MailboxInbox
|
||||
@@ -0,0 +1,350 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/go-acme/lego/v4/certcrypto"
|
||||
"github.com/go-acme/lego/v4/certificate"
|
||||
"github.com/go-acme/lego/v4/challenge"
|
||||
"github.com/go-acme/lego/v4/challenge/http01"
|
||||
"github.com/go-acme/lego/v4/lego"
|
||||
"github.com/go-acme/lego/v4/providers/dns/cloudflare"
|
||||
"github.com/go-acme/lego/v4/providers/dns/digitalocean"
|
||||
"github.com/go-acme/lego/v4/providers/dns/hetzner"
|
||||
"github.com/go-acme/lego/v4/providers/dns/route53"
|
||||
"github.com/go-acme/lego/v4/registration"
|
||||
)
|
||||
|
||||
// ACME manages Let's Encrypt certificate issuance and renewal.
|
||||
type ACME struct {
|
||||
cfg ACMEConfig
|
||||
client *lego.Client
|
||||
account *acmeAccount
|
||||
cacheDir string
|
||||
}
|
||||
|
||||
// ACMEConfig holds all ACME-related settings from app config.
|
||||
type ACMEConfig struct {
|
||||
Email string
|
||||
CacheDir string
|
||||
Staging bool
|
||||
Domains []string // e.g. ["example.com", "*.example.com"]
|
||||
Mode string // "dns01" | "http01"
|
||||
DNSProvider string // "cloudflare" | "route53" | "digitalocean" | "hetzner" | ...
|
||||
}
|
||||
|
||||
// acmeAccount implements lego's registration.User interface.
|
||||
type acmeAccount struct {
|
||||
Email string `json:"email"`
|
||||
Registration *registration.Resource `json:"registration"`
|
||||
key crypto.PrivateKey
|
||||
keyPEM []byte
|
||||
}
|
||||
|
||||
func (a *acmeAccount) GetEmail() string { return a.Email }
|
||||
func (a *acmeAccount) GetRegistration() *registration.Resource { return a.Registration }
|
||||
func (a *acmeAccount) GetPrivateKey() crypto.PrivateKey { return a.key }
|
||||
|
||||
// NewACME initialises the ACME client. Does not yet obtain a cert.
|
||||
// Call ObtainOrRenew() to get/refresh the certificate.
|
||||
func NewACME(cfg ACMEConfig) (*ACME, error) {
|
||||
if cfg.Email == "" {
|
||||
return nil, fmt.Errorf("ACME_EMAIL required for Let's Encrypt")
|
||||
}
|
||||
if len(cfg.Domains) == 0 {
|
||||
return nil, fmt.Errorf("ACME_DOMAINS must not be empty")
|
||||
}
|
||||
if err := os.MkdirAll(cfg.CacheDir, 0700); err != nil {
|
||||
return nil, fmt.Errorf("acme cache dir: %w", err)
|
||||
}
|
||||
|
||||
a := &ACME{cfg: cfg, cacheDir: cfg.CacheDir}
|
||||
|
||||
acc, err := a.loadOrCreateAccount()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acme account: %w", err)
|
||||
}
|
||||
a.account = acc
|
||||
|
||||
legoConfig := lego.NewConfig(a.account)
|
||||
if cfg.Staging {
|
||||
legoConfig.CADirURL = lego.LEDirectoryStaging
|
||||
} else {
|
||||
legoConfig.CADirURL = lego.LEDirectoryProduction
|
||||
}
|
||||
legoConfig.Certificate.KeyType = certcrypto.RSA2048
|
||||
|
||||
client, err := lego.NewClient(legoConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lego client: %w", err)
|
||||
}
|
||||
a.client = client
|
||||
|
||||
// Configure challenge provider.
|
||||
if err := a.setProvider(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register account if not already registered.
|
||||
if acc.Registration == nil {
|
||||
reg, err := client.Registration.Register(registration.RegisterOptions{
|
||||
TermsOfServiceAgreed: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acme register: %w", err)
|
||||
}
|
||||
a.account.Registration = reg
|
||||
if err := a.saveAccount(); err != nil {
|
||||
return nil, fmt.Errorf("save acme account: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// setProvider configures the ACME challenge based on cfg.Mode and cfg.DNSProvider.
|
||||
func (a *ACME) setProvider() error {
|
||||
switch a.cfg.Mode {
|
||||
case "http01":
|
||||
// Lego manages an ephemeral HTTP server on port 80.
|
||||
return a.client.Challenge.SetHTTP01Provider(http01.NewProviderServer("", "80"))
|
||||
|
||||
case "dns01":
|
||||
provider, err := a.buildDNSProvider()
|
||||
if err != nil {
|
||||
return fmt.Errorf("dns01 provider %q: %w", a.cfg.DNSProvider, err)
|
||||
}
|
||||
return a.client.Challenge.SetDNS01Provider(provider)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unknown ACME mode %q (want dns01 or http01)", a.cfg.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
// buildDNSProvider returns the lego DNS provider for cfg.DNSProvider.
|
||||
// Credentials are read from env vars set by config.exportProviderEnv().
|
||||
func (a *ACME) buildDNSProvider() (challenge.Provider, error) {
|
||||
switch a.cfg.DNSProvider {
|
||||
case "cloudflare":
|
||||
return cloudflare.NewDNSProvider()
|
||||
|
||||
case "route53":
|
||||
return route53.NewDNSProvider()
|
||||
|
||||
case "digitalocean":
|
||||
return digitalocean.NewDNSProvider()
|
||||
|
||||
case "hetzner":
|
||||
return hetzner.NewDNSProvider()
|
||||
|
||||
default:
|
||||
// Generic: lego supports 90+ providers; if the user sets the correct
|
||||
// env vars and the provider name matches a lego provider, it WILL work
|
||||
// through the lego plugin system. For unlisted providers, document that
|
||||
// the user must set env vars matching the lego provider docs.
|
||||
return nil, fmt.Errorf(
|
||||
"provider %q not built-in; set lego env vars and use a supported provider name.\n"+
|
||||
"Built-in: cloudflare, route53, digitalocean, hetzner.\n"+
|
||||
"Full list: https://go-acme.github.io/lego/dns/",
|
||||
a.cfg.DNSProvider,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ObtainOrRenew returns a valid *tls.Certificate, obtaining or renewing as needed.
|
||||
// Uses cached cert on disk if it has >30 days remaining.
|
||||
func (a *ACME) ObtainOrRenew() (*tls.Certificate, error) {
|
||||
cached, err := a.loadCachedCert()
|
||||
if err == nil && cached != nil {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
return a.obtain()
|
||||
}
|
||||
|
||||
// obtain requests a new certificate from Let's Encrypt.
|
||||
func (a *ACME) obtain() (*tls.Certificate, error) {
|
||||
req := certificate.ObtainRequest{
|
||||
Domains: a.cfg.Domains,
|
||||
Bundle: true, // include full chain in cert PEM
|
||||
}
|
||||
|
||||
res, err := a.client.Certificate.Obtain(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acme obtain %v: %w", a.cfg.Domains, err)
|
||||
}
|
||||
|
||||
if err := a.saveCert(res); err != nil {
|
||||
// Non-fatal: cert is in memory, just can't persist.
|
||||
fmt.Printf("[acme] WARNING: could not cache cert: %v\n", err)
|
||||
}
|
||||
|
||||
return parseCert(res.Certificate, res.PrivateKey)
|
||||
}
|
||||
|
||||
// RenewalLoop blocks forever, renewing the cert 30 days before expiry.
|
||||
// manager.UpdateCert() is called on each renewal to hot-swap the TLS config.
|
||||
// Call as a goroutine. stopCh receives when it should quit.
|
||||
func (a *ACME) RenewalLoop(manager *Manager, stopCh <-chan struct{}) {
|
||||
ticker := time.NewTicker(12 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
cached, err := a.loadCachedCert()
|
||||
if err != nil || cached == nil {
|
||||
// No cert or corrupt cache — re-obtain.
|
||||
cert, err := a.obtain()
|
||||
if err != nil {
|
||||
fmt.Printf("[acme] renewal obtain error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
manager.UpdateCert(cert)
|
||||
fmt.Printf("[acme] cert renewed for %v\n", a.cfg.Domains)
|
||||
}
|
||||
// loadCachedCert returns nil if cert expires in < 30 days → triggers re-obtain above.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Cert persistence ----
|
||||
|
||||
func (a *ACME) certPath() string { return filepath.Join(a.cacheDir, "cert.pem") }
|
||||
func (a *ACME) keyPath() string { return filepath.Join(a.cacheDir, "key.pem") }
|
||||
func (a *ACME) accPath() string { return filepath.Join(a.cacheDir, "account.json") }
|
||||
func (a *ACME) accKeyPath() string { return filepath.Join(a.cacheDir, "account.key") }
|
||||
|
||||
func (a *ACME) saveCert(res *certificate.Resource) error {
|
||||
if err := os.WriteFile(a.certPath(), res.Certificate, 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(a.keyPath(), res.PrivateKey, 0600)
|
||||
}
|
||||
|
||||
// loadCachedCert returns the cached cert if it has >30 days remaining, nil otherwise.
|
||||
func (a *ACME) loadCachedCert() (*tls.Certificate, error) {
|
||||
certPEM, err := os.ReadFile(a.certPath())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyPEM, err := os.ReadFile(a.keyPath())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cert, err := parseCert(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check expiry — renew if <30 days left.
|
||||
if cert.Leaf != nil && time.Until(cert.Leaf.NotAfter) < 30*24*time.Hour {
|
||||
return nil, nil // signal: needs renewal
|
||||
}
|
||||
// Parse leaf if not already parsed.
|
||||
if cert.Leaf == nil {
|
||||
leaf, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err == nil && time.Until(leaf.NotAfter) < 30*24*time.Hour {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func parseCert(certPEM, keyPEM []byte) (*tls.Certificate, error) {
|
||||
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse cert: %w", err)
|
||||
}
|
||||
// Pre-parse leaf for expiry checks.
|
||||
if len(cert.Certificate) > 0 {
|
||||
leaf, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err == nil {
|
||||
cert.Leaf = leaf
|
||||
}
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// ---- Account persistence ----
|
||||
|
||||
func (a *ACME) loadOrCreateAccount() (*acmeAccount, error) {
|
||||
// Try loading existing account.
|
||||
accData, errAcc := os.ReadFile(a.accPath())
|
||||
keyData, errKey := os.ReadFile(a.accKeyPath())
|
||||
|
||||
if errAcc == nil && errKey == nil {
|
||||
acc := &acmeAccount{}
|
||||
if err := json.Unmarshal(accData, acc); err != nil {
|
||||
return nil, fmt.Errorf("parse account: %w", err)
|
||||
}
|
||||
key, err := parseECKey(keyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse account key: %w", err)
|
||||
}
|
||||
acc.key = key
|
||||
acc.keyPEM = keyData
|
||||
return acc, nil
|
||||
}
|
||||
|
||||
// Generate new account key.
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gen account key: %w", err)
|
||||
}
|
||||
keyPEM, err := encodeECKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
acc := &acmeAccount{
|
||||
Email: a.cfg.Email,
|
||||
key: key,
|
||||
keyPEM: keyPEM,
|
||||
}
|
||||
return acc, nil
|
||||
}
|
||||
|
||||
func (a *ACME) saveAccount() error {
|
||||
data, err := json.MarshalIndent(a.account, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(a.accPath(), data, 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(a.accKeyPath(), a.account.keyPEM, 0600)
|
||||
}
|
||||
|
||||
// ---- EC key helpers ----
|
||||
|
||||
func encodeECKey(key *ecdsa.PrivateKey) ([]byte, error) {
|
||||
der, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal ec key: %w", err)
|
||||
}
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der}), nil
|
||||
}
|
||||
|
||||
func parseECKey(pemData []byte) (*ecdsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(pemData)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("no PEM block in account key")
|
||||
}
|
||||
return x509.ParseECPrivateKey(block.Bytes)
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
// Package tls provides TLS configuration builders, cert loading,
|
||||
// and hot-reload on SIGHUP.
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// MinVersion enforced across all listeners.
|
||||
const MinVersion = tls.VersionTLS12
|
||||
|
||||
// cipherSuites is the approved list — ECDHE + AEAD only.
|
||||
var cipherSuites = []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
}
|
||||
|
||||
// Manager holds the current TLS certificate and rebuilds tls.Config on demand.
|
||||
// Thread-safe: cert updates are atomic.
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
certFile string
|
||||
keyFile string
|
||||
cert unsafe.Pointer // *tls.Certificate, accessed atomically
|
||||
}
|
||||
|
||||
// NewManager creates a Manager that loads cert + key from disk.
|
||||
// Call Reload() after obtaining/renewing a cert.
|
||||
func NewManager(certFile, keyFile string) (*Manager, error) {
|
||||
m := &Manager{certFile: certFile, keyFile: keyFile}
|
||||
if certFile != "" && keyFile != "" {
|
||||
if err := m.Reload(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Reload re-reads the cert and key from disk atomically.
|
||||
// Safe to call from a SIGHUP handler or ACME renewal goroutine.
|
||||
func (m *Manager) Reload() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.certFile == "" || m.keyFile == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(m.certFile, m.keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tls load cert %s / %s: %w", m.certFile, m.keyFile, err)
|
||||
}
|
||||
atomic.StorePointer(&m.cert, unsafe.Pointer(&cert))
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateCert replaces the in-memory certificate (used by ACME after renewal).
|
||||
func (m *Manager) UpdateCert(cert *tls.Certificate) {
|
||||
atomic.StorePointer(&m.cert, unsafe.Pointer(cert))
|
||||
}
|
||||
|
||||
// GetCertificate implements tls.Config.GetCertificate — returns the current cert
|
||||
// regardless of SNI (single-domain or wildcard usage).
|
||||
func (m *Manager) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
p := atomic.LoadPointer(&m.cert)
|
||||
if p == nil {
|
||||
return nil, fmt.Errorf("no TLS certificate loaded")
|
||||
}
|
||||
return (*tls.Certificate)(p), nil
|
||||
}
|
||||
|
||||
// Config returns a *tls.Config with secure defaults and GetCertificate set.
|
||||
func (m *Manager) Config() *tls.Config {
|
||||
cfg := &tls.Config{
|
||||
GetCertificate: m.GetCertificate,
|
||||
MinVersion: MinVersion,
|
||||
CipherSuites: cipherSuites,
|
||||
// TLS 1.3 cipher suites are configured by the Go runtime automatically.
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ConfigForSMTP returns a tls.Config for SMTP servers.
|
||||
// SMTP requires setting ServerName in ClientHello so SNI works;
|
||||
// the GetCertificate callback handles the lookup.
|
||||
func (m *Manager) ConfigForSMTP() *tls.Config {
|
||||
cfg := m.Config()
|
||||
// SMTP clients often don't send SNI, so always return our cert.
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ---- Manual cert loader (TLS_MODE=manual without hot-reload) ----
|
||||
|
||||
// LoadCert loads a certificate pair from disk and returns a ready tls.Config.
|
||||
// Used for services where the Manager is not needed (e.g. one-off TLS dial).
|
||||
func LoadCert(certFile, keyFile string) (*tls.Config, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load cert %s: %w", certFile, err)
|
||||
}
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: MinVersion,
|
||||
CipherSuites: cipherSuites,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ---- TLS_MODE=off placeholder ----
|
||||
|
||||
// IsOff returns true if the given mode string means no TLS.
|
||||
func IsOff(mode string) bool { return mode == "off" || mode == "" }
|
||||
|
||||
// ---- Per-service cert resolution ----
|
||||
|
||||
// ServiceCert describes the cert files for one service (SMTP, IMAP, web).
|
||||
type ServiceCert struct {
|
||||
CertFile string
|
||||
KeyFile string
|
||||
}
|
||||
|
||||
// Resolve returns the cert pair to use for a service:
|
||||
// service-specific override → global fallback → empty (ACME handles it).
|
||||
func Resolve(svcCert, svcKey, globalCert, globalKey string) ServiceCert {
|
||||
if svcCert != "" && svcKey != "" {
|
||||
return ServiceCert{CertFile: svcCert, KeyFile: svcKey}
|
||||
}
|
||||
if globalCert != "" && globalKey != "" {
|
||||
return ServiceCert{CertFile: globalCert, KeyFile: globalKey}
|
||||
}
|
||||
return ServiceCert{} // ACME-managed
|
||||
}
|
||||
|
||||
// FileExists returns true if both files exist.
|
||||
func (sc ServiceCert) FileExists() bool {
|
||||
if sc.CertFile == "" || sc.KeyFile == "" {
|
||||
return false
|
||||
}
|
||||
_, errC := os.Stat(sc.CertFile)
|
||||
_, errK := os.Stat(sc.KeyFile)
|
||||
return errC == nil && errK == nil
|
||||
}
|
||||
|
||||
// ---- TLS listener ----
|
||||
|
||||
// NewListener wraps a net.Listener with TLS using the provided config.
|
||||
func NewListener(ln net.Listener, cfg *tls.Config) net.Listener {
|
||||
return tls.NewListener(ln, cfg)
|
||||
}
|
||||
Reference in New Issue
Block a user