Files
crowdsec-dashy/internal/geoip/updater.go
T

234 lines
5.2 KiB
Go

// Package geoip manages automatic download and refresh of the ipinfo.io MMDB file.
package geoip
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"sync"
	"time"
)
// baseURL is the ipinfo.io free database endpoint — not user-configurable to prevent SSRF.
// download builds the request URL by appending the DB file name and the
// token query parameter to this prefix, so it must keep its trailing slash.
const baseURL = "https://ipinfo.io/data/free/"
// Updater fetches an ipinfo.io MMDB file and keeps it fresh on a schedule.
// It holds a mutex, so it must be used through a pointer and never copied.
type Updater struct {
	// Configuration, populated by New.
	token       string       // ipinfo.io access token; empty disables refreshing
	dbFile      string       // e.g. "asn.mmdb"
	dbPath      string       // absolute local destination
	refreshDays int          // refresh interval, in days
	http        *http.Client // client used for the download request

	// mu guards the refresh state below.
	mu          sync.RWMutex
	lastUpdated time.Time // completion time of the last successful refresh
	lastErr     error     // outcome of the last refresh attempt; nil on success
	updating    bool      // true while a refresh is in flight
}
// Status is a point-in-time snapshot returned by Status().
type Status struct {
// DBPath is the updater's configured local destination path.
DBPath string
// DBFile is the configured remote DB file name (e.g. "asn.mmdb").
DBFile string
// LastUpdated is the time of the last successful refresh; when no refresh
// has happened yet but the file exists, it is the file's modification time.
LastUpdated time.Time
// NextRefresh is LastUpdated plus RefreshDays days; zero when LastUpdated is zero.
NextRefresh time.Time
// LastErrMsg is the text of the last refresh error, or "" if there is none.
LastErrMsg string
// Updating reports whether a refresh was in progress when the snapshot was taken.
Updating bool
// DBExists reports whether os.Stat on DBPath succeeded.
DBExists bool
// DBSizeBytes is the size of the DB file in bytes (0 when DBExists is false).
DBSizeBytes int64
// DBSizeHuman is DBSizeBytes rendered by formatBytes ("" when DBExists is false).
DBSizeHuman string
// TokenSet reports whether a non-empty token is configured.
TokenSet bool
// RefreshDays is the configured refresh interval in days.
RefreshDays int
}
// New builds an Updater for the given token, remote DB file name, local
// destination path and refresh interval in days. The HTTP client it installs
// has a 10-minute overall timeout. Nothing is downloaded until Refresh is
// called or Start is run (typically in its own goroutine).
func New(token, dbFile, dbPath string, refreshDays int) *Updater {
	client := &http.Client{Timeout: 10 * time.Minute}

	u := &Updater{http: client}
	u.token = token
	u.dbFile = dbFile
	u.dbPath = dbPath
	u.refreshDays = refreshDays
	return u
}
// Status returns a snapshot of the updater's state together with what
// os.Stat reports for the DB file. The read lock is held for the whole call,
// including the stat.
func (u *Updater) Status() Status {
	u.mu.RLock()
	defer u.mu.RUnlock()

	var errMsg string
	if u.lastErr != nil {
		errMsg = u.lastErr.Error()
	}

	snap := Status{
		DBPath:      u.dbPath,
		DBFile:      u.dbFile,
		LastUpdated: u.lastUpdated,
		LastErrMsg:  errMsg,
		Updating:    u.updating,
		TokenSet:    u.token != "",
		RefreshDays: u.refreshDays,
	}

	if fi, statErr := os.Stat(u.dbPath); statErr == nil {
		size := fi.Size()
		snap.DBExists = true
		snap.DBSizeBytes = size
		snap.DBSizeHuman = formatBytes(size)
		// No refresh recorded in this process: fall back to the file's mtime.
		if snap.LastUpdated.IsZero() {
			snap.LastUpdated = fi.ModTime()
		}
	}

	// Without any timestamp there is nothing to schedule from.
	if snap.LastUpdated.IsZero() {
		return snap
	}
	interval := time.Duration(u.refreshDays) * 24 * time.Hour
	snap.NextRefresh = snap.LastUpdated.Add(interval)
	return snap
}
// Refresh downloads the DB file once and records the outcome. It may be
// called from several goroutines: while one refresh is running, any other
// caller immediately receives an "already in progress" error. It also fails
// up front when no token is configured. On success lastUpdated is set to the
// current time and the stored error is cleared; on failure the error is
// stored and returned.
func (u *Updater) Refresh(ctx context.Context) error {
	// claim checks the preconditions and takes the in-progress flag under
	// the write lock.
	claim := func() error {
		u.mu.Lock()
		defer u.mu.Unlock()
		switch {
		case u.updating:
			return fmt.Errorf("update already in progress")
		case u.token == "":
			return fmt.Errorf("ipinfo_token not configured in app_config.conf")
		}
		u.updating = true
		return nil
	}
	if err := claim(); err != nil {
		return err
	}
	// Release the flag however the download ends.
	defer func() {
		u.mu.Lock()
		u.updating = false
		u.mu.Unlock()
	}()

	dlErr := u.download(ctx)

	u.mu.Lock()
	u.lastErr = dlErr
	if dlErr == nil {
		u.lastUpdated = time.Now()
	}
	u.mu.Unlock()

	return dlErr
}
// download fetches u.dbFile from the ipinfo.io endpoint and installs it at
// u.dbPath. The response body is streamed to a temp file in the destination
// directory and then renamed into place, so readers never see a partial file.
//
// The token travels in the query string. It is therefore query-escaped when
// the URL is built, and the *url.Error wrapper is stripped from request and
// transport errors: that wrapper embeds the full request URL in its message,
// which would otherwise leak the token into logs and Status().LastErrMsg.
// The underlying cause stays wrapped, so errors.Is(err, context.Canceled)
// and similar checks keep working.
//
// It returns a wrapped error when the request cannot be built or sent, when
// the server answers with anything other than 200 (including up to 512 bytes
// of the body), when the body is empty, or when a filesystem step fails. The
// temp file is removed on every failure path and the existing DB is left
// untouched.
func (u *Updater) download(ctx context.Context) error {
	// stripURL removes the *url.Error layer, whose message contains the token.
	stripURL := func(err error) error {
		var uerr *url.Error
		if errors.As(err, &uerr) {
			return uerr.Err
		}
		return err
	}

	dlURL := baseURL + u.dbFile + "?token=" + url.QueryEscape(u.token)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, dlURL, nil)
	if err != nil {
		return fmt.Errorf("build request: %w", stripURL(err))
	}
	req.Header.Set("User-Agent", "crowdsec-dashy/1.0")

	resp, err := u.http.Do(req)
	if err != nil {
		return fmt.Errorf("download: %w", stripURL(err))
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best-effort excerpt of the error body; a read failure only shortens it.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
	}

	// Temp file in same directory → atomic rename (same filesystem).
	dir := filepath.Dir(u.dbPath)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("create destination dir: %w", err)
	}
	tmp, err := os.CreateTemp(dir, ".ipinfo-*.mmdb.tmp")
	if err != nil {
		return fmt.Errorf("create temp file: %w", err)
	}
	tmpPath := tmp.Name()

	written, writeErr := io.Copy(tmp, resp.Body)
	// A write error can surface only at Close, so it must not be ignored.
	if closeErr := tmp.Close(); writeErr == nil {
		writeErr = closeErr
	}
	if writeErr != nil {
		_ = os.Remove(tmpPath) // best-effort cleanup
		return fmt.Errorf("write temp file: %w", writeErr)
	}
	// An empty body is never a valid database; keep the existing file.
	if written == 0 {
		_ = os.Remove(tmpPath) // best-effort cleanup
		return fmt.Errorf("download: empty response body")
	}

	if err := os.Rename(tmpPath, u.dbPath); err != nil {
		_ = os.Remove(tmpPath) // best-effort cleanup
		return fmt.Errorf("rename to %s: %w", u.dbPath, err)
	}
	return nil
}
// Start drives automatic refreshing and blocks until ctx is cancelled, so it
// is meant to be launched as a goroutine: go updater.Start(ctx).
//
// With no token configured it logs a notice and returns immediately.
// Otherwise it refreshes once right away if the DB file is missing or past
// its next-refresh time, then re-checks every 12 hours and refreshes when
// due. Refresh failures are logged, not returned.
func (u *Updater) Start(ctx context.Context) {
	if u.token == "" {
		log.Println("[geoip] ipinfo_token not set — auto-refresh disabled")
		return
	}

	// Startup check.
	initial := u.Status()
	stale := !initial.LastUpdated.IsZero() && time.Now().After(initial.NextRefresh)
	if stale || !initial.DBExists {
		log.Println("[geoip] DB missing or stale — refreshing now")
		err := u.Refresh(ctx)
		if err == nil {
			log.Printf("[geoip] DB saved to %s", u.dbPath)
		} else {
			log.Printf("[geoip] initial refresh failed: %v", err)
		}
	}

	// Poll twice a day; the refresh itself only runs once NextRefresh has passed.
	tick := time.NewTicker(12 * time.Hour)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			cur := u.Status()
			due := cur.TokenSet && !time.Now().Before(cur.NextRefresh)
			if !due {
				continue
			}
			log.Println("[geoip] scheduled refresh starting")
			err := u.Refresh(ctx)
			if err == nil {
				log.Printf("[geoip] scheduled refresh complete")
			} else {
				log.Printf("[geoip] scheduled refresh failed: %v", err)
			}
		case <-ctx.Done():
			return
		}
	}
}
// formatBytes renders a byte count for display. Values below 1024 are printed
// as whole bytes ("512 B"); larger values are scaled by powers of 1024 and
// printed with one decimal and a K/M/G/T/P/E prefix ("1.5 KB").
func formatBytes(n int64) string {
	const (
		base     = 1024
		prefixes = "KMGTPE"
	)
	if n < base {
		return fmt.Sprintf("%d B", n)
	}

	// Grow the divisor until the scaled value drops below one further step.
	scale, idx := int64(base), 0
	for n/scale >= base {
		scale *= base
		idx++
	}
	scaled := float64(n) / float64(scale)
	return fmt.Sprintf("%.1f %cB", scaled, prefixes[idx])
}