202 lines
5.4 KiB
Go
202 lines
5.4 KiB
Go
package auth
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"ghb.freebede.com/nahakubuilder/mailgosend/internal/db"
)
|
|
|
|
// BruteGuard tracks login attempts and bans IPs exceeding the threshold.
|
|
type BruteGuard struct {
|
|
db *db.DB
|
|
maxTries int
|
|
windowMin int
|
|
banHours int
|
|
whitelist map[string]struct{} // exempt IPs
|
|
}
|
|
|
|
// NewBruteGuard creates a brute-force guard.
|
|
func NewBruteGuard(database *db.DB, maxTries, windowMin, banHours int, whitelist []string) *BruteGuard {
|
|
wl := make(map[string]struct{}, len(whitelist))
|
|
for _, ip := range whitelist {
|
|
wl[ip] = struct{}{}
|
|
}
|
|
return &BruteGuard{
|
|
db: database,
|
|
maxTries: maxTries,
|
|
windowMin: windowMin,
|
|
banHours: banHours,
|
|
whitelist: wl,
|
|
}
|
|
}
|
|
|
|
// IsBanned returns true if the IP is currently banned.
|
|
func (g *BruteGuard) IsBanned(ctx context.Context, ip string) (bool, error) {
|
|
if _, ok := g.whitelist[ip]; ok {
|
|
return false, nil
|
|
}
|
|
|
|
var expiresAt sql.NullTime
|
|
err := g.db.SQL().QueryRowContext(ctx,
|
|
"SELECT expires_at FROM ip_bans WHERE ip = ?", ip).Scan(&expiresAt)
|
|
if err == sql.ErrNoRows {
|
|
return false, nil
|
|
}
|
|
if err != nil {
|
|
return false, fmt.Errorf("check ban: %w", err)
|
|
}
|
|
|
|
// Permanent ban (expires_at IS NULL) or not yet expired.
|
|
if !expiresAt.Valid || expiresAt.Time.After(time.Now().UTC()) {
|
|
return true, nil
|
|
}
|
|
|
|
// Expired ban — clean up.
|
|
_, _ = g.db.SQL().ExecContext(ctx, "DELETE FROM ip_bans WHERE ip = ?", ip)
|
|
return false, nil
|
|
}
|
|
|
|
// RecordAttempt records a login attempt and bans the IP if threshold exceeded.
|
|
// Returns (banned, error).
|
|
func (g *BruteGuard) RecordAttempt(ctx context.Context, ip, email string, success bool) (bool, error) {
|
|
if _, ok := g.whitelist[ip]; ok {
|
|
return false, nil
|
|
}
|
|
|
|
// Insert attempt.
|
|
_, err := g.db.SQL().ExecContext(ctx,
|
|
"INSERT INTO login_attempts (ip, user_email, success, created_at) VALUES (?, ?, ?, ?)",
|
|
ip, email, success, time.Now().UTC())
|
|
if err != nil {
|
|
return false, fmt.Errorf("record attempt: %w", err)
|
|
}
|
|
|
|
if success {
|
|
// Successful login resets the counter (remove recent failed attempts).
|
|
window := time.Now().UTC().Add(-time.Duration(g.windowMin) * time.Minute)
|
|
_, _ = g.db.SQL().ExecContext(ctx,
|
|
"DELETE FROM login_attempts WHERE ip = ? AND success = 0 AND created_at >= ?",
|
|
ip, window)
|
|
return false, nil
|
|
}
|
|
|
|
// Count failures in window.
|
|
window := time.Now().UTC().Add(-time.Duration(g.windowMin) * time.Minute)
|
|
var count int
|
|
err = g.db.SQL().QueryRowContext(ctx,
|
|
"SELECT COUNT(*) FROM login_attempts WHERE ip = ? AND success = 0 AND created_at >= ?",
|
|
ip, window).Scan(&count)
|
|
if err != nil {
|
|
return false, fmt.Errorf("count attempts: %w", err)
|
|
}
|
|
|
|
if count < g.maxTries {
|
|
return false, nil
|
|
}
|
|
|
|
// Threshold exceeded — ban the IP.
|
|
var expiresAt *time.Time
|
|
if g.banHours > 0 {
|
|
t := time.Now().UTC().Add(time.Duration(g.banHours) * time.Hour)
|
|
expiresAt = &t
|
|
}
|
|
|
|
_, err = g.db.SQL().ExecContext(ctx, `
|
|
INSERT INTO ip_bans (ip, reason, banned_at, expires_at)
|
|
VALUES (?, ?, ?, ?)
|
|
ON CONFLICT(ip) DO UPDATE SET
|
|
reason = excluded.reason,
|
|
banned_at = excluded.banned_at,
|
|
expires_at = excluded.expires_at`,
|
|
ip,
|
|
fmt.Sprintf("brute force: %d failed attempts in %d minutes", count, g.windowMin),
|
|
time.Now().UTC(),
|
|
expiresAt,
|
|
)
|
|
if err != nil {
|
|
return false, fmt.Errorf("ban ip: %w", err)
|
|
}
|
|
|
|
// Log security event.
|
|
_, _ = g.db.SQL().ExecContext(ctx, `
|
|
INSERT INTO security_events (type, ip, detail, created_at)
|
|
VALUES ('brute_ban', ?, ?, ?)`,
|
|
ip,
|
|
fmt.Sprintf("%d failed attempts in %d min window, last target: %s", count, g.windowMin, email),
|
|
time.Now().UTC(),
|
|
)
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// BanIP manually bans an IP (from admin panel).
|
|
func (g *BruteGuard) BanIP(ctx context.Context, ip, reason, bannedBy string, hours int) error {
|
|
var expiresAt *time.Time
|
|
if hours > 0 {
|
|
t := time.Now().UTC().Add(time.Duration(hours) * time.Hour)
|
|
expiresAt = &t
|
|
}
|
|
_, err := g.db.SQL().ExecContext(ctx, `
|
|
INSERT INTO ip_bans (ip, reason, banned_at, expires_at)
|
|
VALUES (?, ?, ?, ?)
|
|
ON CONFLICT(ip) DO UPDATE SET
|
|
reason = excluded.reason,
|
|
banned_at = excluded.banned_at,
|
|
expires_at = excluded.expires_at`,
|
|
ip, reason, time.Now().UTC(), expiresAt)
|
|
return err
|
|
}
|
|
|
|
// UnbanIP removes a ban.
|
|
func (g *BruteGuard) UnbanIP(ctx context.Context, ip, releasedBy string) error {
|
|
_, err := g.db.SQL().ExecContext(ctx,
|
|
"DELETE FROM ip_bans WHERE ip = ?", ip)
|
|
return err
|
|
}
|
|
|
|
// ListBans returns active bans ordered by newest first.
|
|
func (g *BruteGuard) ListBans(ctx context.Context, limit int) ([]banRow, error) {
|
|
rows, err := g.db.SQL().QueryContext(ctx, `
|
|
SELECT ip, reason, banned_at, expires_at
|
|
FROM ip_bans
|
|
ORDER BY banned_at DESC
|
|
LIMIT ?`, limit)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var bans []banRow
|
|
for rows.Next() {
|
|
var b banRow
|
|
var exp sql.NullTime
|
|
if err := rows.Scan(&b.IP, &b.Reason, &b.BannedAt, &exp); err != nil {
|
|
return nil, err
|
|
}
|
|
if exp.Valid {
|
|
b.ExpiresAt = &exp.Time
|
|
}
|
|
bans = append(bans, b)
|
|
}
|
|
return bans, rows.Err()
|
|
}
|
|
|
|
// PurgeOldAttempts removes login attempt records older than windowMin*2.
|
|
// Call periodically to keep the table small.
|
|
func (g *BruteGuard) PurgeOldAttempts(ctx context.Context) error {
|
|
cutoff := time.Now().UTC().Add(-2 * time.Duration(g.windowMin) * time.Minute)
|
|
_, err := g.db.SQL().ExecContext(ctx,
|
|
"DELETE FROM login_attempts WHERE created_at < ?", cutoff)
|
|
return err
|
|
}
|
|
|
|
// banRow is a single ip_bans record as returned by ListBans.
type banRow struct {
	IP, Reason string
	BannedAt   time.Time
	ExpiresAt  *time.Time // nil when the stored expires_at is NULL (no expiry)
}
|