Files
mailgosend/internal/auth/brute.go
T
2026-05-21 20:27:58 +00:00

202 lines
5.4 KiB
Go

package auth
import (
"context"
"database/sql"
"fmt"
"time"
"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
)
// BruteGuard records login attempts per IP address and bans addresses whose
// recent failures reach the configured limit. Attempts and bans are persisted
// through db; the struct itself only carries configuration.
type BruteGuard struct {
	// db is the handle every query issued by this type goes through.
	db *db.DB

	// maxTries is the number of failed attempts inside the window that
	// triggers a ban, windowMin is the window length in minutes, and banHours
	// is the automatic ban length in hours (a value <= 0 yields a ban without
	// an expiry).
	maxTries, windowMin, banHours int

	// whitelist is the set of IPs exempt from attempt tracking and ban checks.
	whitelist map[string]struct{}
}
// NewBruteGuard builds a BruteGuard that stores its data in database.
//
// maxTries, windowMin and banHours are copied verbatim into the guard. Every
// entry of whitelist is stored as-is (no trimming or normalisation) in a set
// used for exact-match lookups.
func NewBruteGuard(database *db.DB, maxTries, windowMin, banHours int, whitelist []string) *BruteGuard {
	guard := &BruteGuard{
		db:        database,
		maxTries:  maxTries,
		windowMin: windowMin,
		banHours:  banHours,
		whitelist: make(map[string]struct{}, len(whitelist)),
	}
	for i := range whitelist {
		guard.whitelist[whitelist[i]] = struct{}{}
	}
	return guard
}
// IsBanned reports whether ip currently has an active ban.
//
// Whitelisted IPs are never reported as banned. A row in ip_bans whose
// expires_at is NULL is a permanent ban. A row whose expiry has passed counts
// as "not banned" and is removed on a best-effort basis.
//
// A non-nil error is returned only when the lookup itself fails; in that case
// the boolean is false.
func (g *BruteGuard) IsBanned(ctx context.Context, ip string) (bool, error) {
	if _, ok := g.whitelist[ip]; ok {
		return false, nil
	}

	var expiresAt sql.NullTime
	err := g.db.SQL().QueryRowContext(ctx,
		"SELECT expires_at FROM ip_bans WHERE ip = ?", ip).Scan(&expiresAt)
	switch {
	case err == sql.ErrNoRows:
		return false, nil
	case err != nil:
		return false, fmt.Errorf("check ban: %w", err)
	}

	// Read the clock once so the expiry decision and the cleanup below agree.
	now := time.Now().UTC()

	// Permanent ban (expires_at IS NULL) or not yet expired.
	if !expiresAt.Valid || expiresAt.Time.After(now) {
		return true, nil
	}

	// Expired ban: clean up. The DELETE is conditional on the row still being
	// expired so that a ban written concurrently (between the SELECT above and
	// this statement) is not wiped out. The result is deliberately ignored:
	// a failed cleanup leaves a stale row that is re-evaluated on the next
	// call and does not change the answer.
	_, _ = g.db.SQL().ExecContext(ctx,
		"DELETE FROM ip_bans WHERE ip = ? AND expires_at IS NOT NULL AND expires_at <= ?",
		ip, now)
	return false, nil
}
// RecordAttempt stores one login attempt for ip and, for failures, bans the
// IP once the number of failed attempts inside the window reaches maxTries.
//
// Parameters:
//   - ip: client address; whitelisted addresses are ignored entirely.
//   - email: account the attempt targeted; stored with the attempt and
//     included in the security event written on a ban.
//   - success: whether the login succeeded. A success clears the IP's failed
//     attempts inside the current window.
//
// It returns (true, nil) only when this call placed (or refreshed) a ban.
// Errors from recording, counting or banning are wrapped and returned with
// banned == false.
func (g *BruteGuard) RecordAttempt(ctx context.Context, ip, email string, success bool) (bool, error) {
	if _, ok := g.whitelist[ip]; ok {
		return false, nil
	}

	// One clock reading for the whole call: the stored attempt, the window
	// boundary and the ban timestamps are all derived from the same instant,
	// so the attempt inserted below always falls inside its own window.
	now := time.Now().UTC()
	windowStart := now.Add(-time.Duration(g.windowMin) * time.Minute)

	// Insert attempt.
	if _, err := g.db.SQL().ExecContext(ctx,
		"INSERT INTO login_attempts (ip, user_email, success, created_at) VALUES (?, ?, ?, ?)",
		ip, email, success, now); err != nil {
		return false, fmt.Errorf("record attempt: %w", err)
	}

	if success {
		// Successful login resets the counter (remove recent failed attempts).
		// Best effort: if the reset fails the login itself is still valid, so
		// the error is intentionally dropped.
		_, _ = g.db.SQL().ExecContext(ctx,
			"DELETE FROM login_attempts WHERE ip = ? AND success = 0 AND created_at >= ?",
			ip, windowStart)
		return false, nil
	}

	// Count failures in window (includes the attempt inserted above).
	var count int
	if err := g.db.SQL().QueryRowContext(ctx,
		"SELECT COUNT(*) FROM login_attempts WHERE ip = ? AND success = 0 AND created_at >= ?",
		ip, windowStart).Scan(&count); err != nil {
		return false, fmt.Errorf("count attempts: %w", err)
	}
	if count < g.maxTries {
		return false, nil
	}

	// Threshold reached: ban the IP. A nil expiry is stored as NULL, which
	// IsBanned treats as a permanent ban.
	var expiresAt *time.Time
	if g.banHours > 0 {
		t := now.Add(time.Duration(g.banHours) * time.Hour)
		expiresAt = &t
	}
	if _, err := g.db.SQL().ExecContext(ctx, `
		INSERT INTO ip_bans (ip, reason, banned_at, expires_at)
		VALUES (?, ?, ?, ?)
		ON CONFLICT(ip) DO UPDATE SET
			reason = excluded.reason,
			banned_at = excluded.banned_at,
			expires_at = excluded.expires_at`,
		ip,
		fmt.Sprintf("brute force: %d failed attempts in %d minutes", count, g.windowMin),
		now,
		expiresAt,
	); err != nil {
		return false, fmt.Errorf("ban ip: %w", err)
	}

	// Log security event. Best effort: the ban is already in place, so a
	// failure to write the audit row must not turn the call into an error.
	_, _ = g.db.SQL().ExecContext(ctx, `
		INSERT INTO security_events (type, ip, detail, created_at)
		VALUES ('brute_ban', ?, ?, ?)`,
		ip,
		fmt.Sprintf("%d failed attempts in %d min window, last target: %s", count, g.windowMin, email),
		now,
	)
	return true, nil
}
// BanIP manually bans an IP (from admin panel), replacing any existing ban
// for the same address.
//
// Parameters:
//   - ip: address to ban. The whitelist is not consulted here, but IsBanned
//     still reports whitelisted addresses as not banned.
//   - reason: free-text reason stored with the ban.
//   - bannedBy: accepted for API compatibility; this implementation does not
//     persist it.
//   - hours: ban length in hours; a value <= 0 stores no expiry, i.e. a
//     permanent ban.
//
// Any database error is wrapped and returned.
func (g *BruteGuard) BanIP(ctx context.Context, ip, reason, bannedBy string, hours int) error {
	// Derive banned_at and expires_at from the same instant.
	now := time.Now().UTC()

	var expiresAt *time.Time
	if hours > 0 {
		t := now.Add(time.Duration(hours) * time.Hour)
		expiresAt = &t
	}

	if _, err := g.db.SQL().ExecContext(ctx, `
		INSERT INTO ip_bans (ip, reason, banned_at, expires_at)
		VALUES (?, ?, ?, ?)
		ON CONFLICT(ip) DO UPDATE SET
			reason = excluded.reason,
			banned_at = excluded.banned_at,
			expires_at = excluded.expires_at`,
		ip, reason, now, expiresAt); err != nil {
		return fmt.Errorf("ban ip: %w", err)
	}
	return nil
}
// UnbanIP removes the ban row for ip, if any. It does not check whether a row
// was actually deleted.
//
// releasedBy is accepted for API compatibility; this implementation does not
// persist it. Any database error is wrapped and returned.
func (g *BruteGuard) UnbanIP(ctx context.Context, ip, releasedBy string) error {
	if _, err := g.db.SQL().ExecContext(ctx,
		"DELETE FROM ip_bans WHERE ip = ?", ip); err != nil {
		return fmt.Errorf("unban ip: %w", err)
	}
	return nil
}
// ListBans returns active bans ordered by newest first, at most limit rows.
//
// "Active" means the ban has no expiry or its expiry lies in the future.
// Expired rows that have not been cleaned up yet are excluded, so the result
// agrees with what IsBanned would report.
//
// On any query, scan or iteration error the returned slice is nil and the
// error is wrapped.
func (g *BruteGuard) ListBans(ctx context.Context, limit int) ([]banRow, error) {
	rows, err := g.db.SQL().QueryContext(ctx, `
		SELECT ip, reason, banned_at, expires_at
		FROM ip_bans
		WHERE expires_at IS NULL OR expires_at > ?
		ORDER BY banned_at DESC
		LIMIT ?`, time.Now().UTC(), limit)
	if err != nil {
		return nil, fmt.Errorf("list bans: %w", err)
	}
	defer rows.Close()

	var bans []banRow
	for rows.Next() {
		var (
			b   banRow
			exp sql.NullTime
		)
		if err := rows.Scan(&b.IP, &b.Reason, &b.BannedAt, &exp); err != nil {
			return nil, fmt.Errorf("scan ban: %w", err)
		}
		if exp.Valid {
			// Copy so the row owns its expiry value.
			t := exp.Time
			b.ExpiresAt = &t
		}
		bans = append(bans, b)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("list bans: %w", err)
	}
	return bans, nil
}
// PurgeOldAttempts removes login attempt records older than windowMin*2
// minutes. Call periodically to keep the table small.
//
// Any database error is wrapped and returned.
func (g *BruteGuard) PurgeOldAttempts(ctx context.Context) error {
	cutoff := time.Now().UTC().Add(-2 * time.Duration(g.windowMin) * time.Minute)
	if _, err := g.db.SQL().ExecContext(ctx,
		"DELETE FROM login_attempts WHERE created_at < ?", cutoff); err != nil {
		return fmt.Errorf("purge attempts: %w", err)
	}
	return nil
}
// banRow is one row of the ip_bans table as returned by ListBans.
type banRow struct {
	// IP is the banned address and Reason the text stored with the ban.
	IP, Reason string

	// BannedAt is when the ban was written.
	BannedAt time.Time

	// ExpiresAt is when the ban lapses; nil means the row has no expiry.
	ExpiresAt *time.Time
}