172 lines
4.4 KiB
Go
172 lines
4.4 KiB
Go
// Package db provides the database/sql wrapper and driver registration.
//
// Default driver: modernc.org/sqlite (pure Go, no CGO).
// Additional drivers registered via build tags: postgres, mysql, mssql.
|
|
package db
|
|
|
|
import (
	"context"
	"database/sql"
	"fmt"
	"strings"
	"time"

	_ "modernc.org/sqlite" // pure-Go SQLite driver
)
|
|
|
|
// DB wraps a *sql.DB together with the driver name it was opened with and
// offers convenience methods on top of it (timeout-bounded queries,
// transactions, driver-specific placeholders).
//
// NOTE(review): an earlier version of this comment mentioned prepared
// statement caching; nothing in this type implements a statement cache.
type DB struct {
	db     *sql.DB // underlying database/sql handle
	driver string  // friendly driver name as passed to Open (e.g. "sqlite", "mssql")
}
|
|
|
|
// Open opens and validates the database connection, runs migrations, returns DB.
|
|
func Open(driver, dsn string) (*DB, error) {
|
|
if driver == "" {
|
|
driver = "sqlite"
|
|
}
|
|
|
|
// Map friendly driver names to database/sql driver names.
|
|
sqlDriver := sqlDriverName(driver)
|
|
|
|
if driver == "sqlite" {
|
|
dsn = sqliteDSN(dsn)
|
|
}
|
|
|
|
sqlDB, err := sql.Open(sqlDriver, dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db open %s: %w", driver, err)
|
|
}
|
|
|
|
// Connection pool tuning.
|
|
if driver == "sqlite" {
|
|
// SQLite: serialise with single connection to avoid SQLITE_BUSY.
|
|
sqlDB.SetMaxOpenConns(1)
|
|
sqlDB.SetMaxIdleConns(1)
|
|
sqlDB.SetConnMaxLifetime(0)
|
|
} else {
|
|
sqlDB.SetMaxOpenConns(25)
|
|
sqlDB.SetMaxIdleConns(10)
|
|
sqlDB.SetConnMaxLifetime(5 * time.Minute)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
if err := sqlDB.PingContext(ctx); err != nil {
|
|
sqlDB.Close()
|
|
return nil, fmt.Errorf("db ping: %w", err)
|
|
}
|
|
|
|
d := &DB{db: sqlDB, driver: driver}
|
|
|
|
// Enable WAL mode for SQLite (dramatically improves concurrent read performance).
|
|
if driver == "sqlite" {
|
|
if _, err := sqlDB.Exec(`PRAGMA journal_mode=WAL`); err != nil {
|
|
sqlDB.Close()
|
|
return nil, fmt.Errorf("sqlite WAL: %w", err)
|
|
}
|
|
if _, err := sqlDB.Exec(`PRAGMA foreign_keys=ON`); err != nil {
|
|
sqlDB.Close()
|
|
return nil, fmt.Errorf("sqlite foreign_keys: %w", err)
|
|
}
|
|
if _, err := sqlDB.Exec(`PRAGMA busy_timeout=5000`); err != nil {
|
|
sqlDB.Close()
|
|
return nil, fmt.Errorf("sqlite busy_timeout: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := d.migrate(); err != nil {
|
|
sqlDB.Close()
|
|
return nil, fmt.Errorf("migrate: %w", err)
|
|
}
|
|
|
|
return d, nil
|
|
}
|
|
|
|
// Close closes the underlying sql.DB.
|
|
func (d *DB) Close() error { return d.db.Close() }
|
|
|
|
// Driver returns the driver name (sqlite / postgres / mysql / mssql).
|
|
func (d *DB) Driver() string { return d.driver }
|
|
|
|
// SQL returns the underlying *sql.DB for direct use when needed.
|
|
func (d *DB) SQL() *sql.DB { return d.db }
|
|
|
|
// Exec runs a query with a per-call context timeout.
|
|
func (d *DB) Exec(query string, args ...any) (sql.Result, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
return d.db.ExecContext(ctx, query, args...)
|
|
}
|
|
|
|
// QueryRow runs a single-row query.
|
|
func (d *DB) QueryRow(query string, args ...any) *sql.Row {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
return d.db.QueryRowContext(ctx, query, args...)
|
|
}
|
|
|
|
// Query runs a multi-row query.
|
|
func (d *DB) Query(query string, args ...any) (*sql.Rows, error) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
return d.db.QueryContext(ctx, query, args...)
|
|
}
|
|
|
|
// WithTx runs fn inside a transaction, rolling back on error or panic.
|
|
func (d *DB) WithTx(ctx context.Context, fn func(*sql.Tx) error) error {
|
|
tx, err := d.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("begin tx: %w", err)
|
|
}
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
_ = tx.Rollback()
|
|
panic(p) // re-raise
|
|
}
|
|
}()
|
|
if err := fn(tx); err != nil {
|
|
_ = tx.Rollback()
|
|
return err
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
// ---- Placeholder helper ----
|
|
|
|
// Placeholder returns the SQL parameter placeholder for the current driver.
|
|
// SQLite and MySQL use ?, PostgreSQL uses $1, $2… MSSQL uses @p1, @p2…
|
|
func (d *DB) Placeholder(n int) string {
|
|
switch d.driver {
|
|
case "postgres":
|
|
return fmt.Sprintf("$%d", n)
|
|
case "mssql":
|
|
return fmt.Sprintf("@p%d", n)
|
|
default:
|
|
return "?"
|
|
}
|
|
}
|
|
|
|
// ---- Private helpers ----
|
|
|
|
// sqlDriverName maps a friendly driver name to the name passed to sql.Open.
// Only "mssql" differs (it becomes "sqlserver"); "sqlite", "postgres", "mysql"
// and any unrecognised name are passed through unchanged.
func sqlDriverName(driver string) string {
	if driver == "mssql" {
		return "sqlserver"
	}
	// modernc.org/sqlite registers as "sqlite"; the remaining names need no
	// translation either.
	return driver
}
|
|
|
|
// sqliteDSN builds the DSN handed to the modernc.org/sqlite driver.
//
// An empty path falls back to "./data/mail.db". The foreign-keys pragma is
// appended as a DSN query parameter; when path already carries a query string
// the parameter is joined with "&" rather than a second "?", which would
// otherwise fold the pragma into the value of the preceding parameter.
func sqliteDSN(path string) string {
	if path == "" {
		path = "./data/mail.db"
	}
	sep := "?"
	if strings.ContainsRune(path, '?') {
		sep = "&"
	}
	// modernc.org/sqlite DSN supports query parameters.
	return path + sep + "_pragma=foreign_keys(1)"
}
|