443 lines
14 KiB
Go
443 lines
14 KiB
Go
package auth
|
|
|
|
import (
	"crypto/rand"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"golang.org/x/crypto/bcrypt"
	_ "modernc.org/sqlite"

	"gobsidian/internal/config"
)
|
|
|
|
type Service struct {
|
|
DB *sql.DB
|
|
Config *config.Config
|
|
}
|
|
|
|
// HasReadAccess determines if a user (or the public) can read a given path.
//
// Semantics:
//   - If no permission rows exist whose path_prefix is a prefix of path,
//     access is ALLOWED by default.
//   - If such rows exist, access is allowed only when at least one matching
//     row has can_read=1 and its group is either named 'public' (which grants
//     everyone) or is a group the user belongs to.
//   - Unauthenticated callers (userID == nil) count only as members of the
//     implicit 'public' group.
//
// Path matching is a prefix match on permissions.path_prefix. The stored
// prefix is escaped before it is used as a LIKE pattern, so '%', '_' and '\'
// inside a prefix are matched literally instead of acting as wildcards (e.g.
// a row for "/team_a" no longer also governs "/teamXa"). LIKE's default
// ASCII case-insensitivity is intentionally left unchanged.
func (s *Service) HasReadAccess(userID *int64, path string) (bool, error) {
	// prefixMatch is true when the bound path starts with p.path_prefix.
	// Backslashes are escaped first so the '\%' and '\_' escapes added
	// afterwards are not themselves doubled.
	const prefixMatch = `? LIKE replace(replace(replace(p.path_prefix, '\', '\\'), '%', '\%'), '_', '\_') || '%' ESCAPE '\'`

	// First check whether any permission row applies to this path at all.
	var total int
	if err := s.DB.QueryRow(`SELECT COUNT(1) FROM permissions p WHERE `+prefixMatch, path).Scan(&total); err != nil {
		return false, fmt.Errorf("counting permissions for %q: %w", path, err)
	}
	// Default allow when no permissions are defined for this path.
	if total == 0 {
		return true, nil
	}

	var (
		query string
		args  []interface{}
	)
	if userID == nil {
		// Public caller: only rows granted to the 'public' group count.
		query = `
			SELECT 1
			FROM permissions p
			JOIN groups g ON g.id = p.group_id
			WHERE p.can_read = 1
			  AND ` + prefixMatch + `
			  AND g.name = 'public'
			LIMIT 1`
		args = []interface{}{path}
	} else {
		// Authenticated caller: 'public' rows or rows of any group the user
		// is a member of (the LEFT JOIN yields a non-NULL ug.user_id then).
		query = `
			SELECT 1
			FROM permissions p
			JOIN groups g ON g.id = p.group_id
			LEFT JOIN user_groups ug ON ug.group_id = g.id AND ug.user_id = ?
			WHERE p.can_read = 1
			  AND ` + prefixMatch + `
			  AND (g.name = 'public' OR ug.user_id IS NOT NULL)
			LIMIT 1`
		args = []interface{}{*userID, path}
	}

	var one int
	switch err := s.DB.QueryRow(query, args...).Scan(&one); {
	case err == nil:
		return true, nil
	case errors.Is(err, sql.ErrNoRows):
		return false, nil
	default:
		return false, fmt.Errorf("checking read permission for %q: %w", path, err)
	}
}
|
|
|
|
// IsUserInGroup reports whether the user with the given ID is a member of the
// group called groupName. A missing membership is not an error: it yields
// (false, nil). Any other database failure is returned unchanged.
func (s *Service) IsUserInGroup(userID int64, groupName string) (bool, error) {
	const membershipQuery = `
		SELECT 1
		FROM user_groups ug
		JOIN groups g ON g.id = ug.group_id
		WHERE ug.user_id = ? AND g.name = ?
		LIMIT 1
	`

	var found int
	switch scanErr := s.DB.QueryRow(membershipQuery, userID, groupName).Scan(&found); {
	case scanErr == nil:
		return true, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return false, nil
	default:
		return false, scanErr
	}
}
|
|
|
|
type User struct {
|
|
ID int64
|
|
Username string
|
|
Email string
|
|
PasswordHash string
|
|
IsActive bool
|
|
EmailConfirmed bool
|
|
MFASecret sql.NullString
|
|
CreatedAt time.Time
|
|
UpdatedAt time.Time
|
|
}
|
|
|
|
// Open creates a Service for cfg. It opens the database (only DBType "sqlite",
// compared case-insensitively, is supported), enables foreign keys, runs the
// schema migrations and makes sure the default admin account exists.
//
// If any step after opening the handle fails, the handle is closed before the
// error is returned, so a failed Open does not leak a connection pool.
func Open(cfg *config.Config) (*Service, error) {
	if strings.ToLower(cfg.DBType) != "sqlite" {
		return nil, fmt.Errorf("unsupported DB type: %s", cfg.DBType)
	}

	db, err := sql.Open("sqlite", cfg.DBPath)
	if err != nil {
		return nil, fmt.Errorf("opening sqlite database %q: %w", cfg.DBPath, err)
	}

	// fail closes the half-initialised handle and wraps err with the step
	// that failed. The Close error is ignored on purpose: the step error is
	// the one the caller needs to see.
	fail := func(step string, err error) (*Service, error) {
		_ = db.Close()
		return nil, fmt.Errorf("%s: %w", step, err)
	}

	// NOTE(review): PRAGMA foreign_keys is per-connection in SQLite, but
	// *sql.DB is a pool, so this only affects the pooled connection that runs
	// it. The ON DELETE CASCADE / SET NULL clauses in migrate depend on it;
	// consider enabling it through the driver DSN or limiting the pool —
	// confirm against the modernc.org/sqlite documentation.
	if _, err := db.Exec("PRAGMA foreign_keys = ON;"); err != nil {
		return fail("enabling sqlite foreign keys", err)
	}

	s := &Service{DB: db, Config: cfg}
	if err := s.migrate(); err != nil {
		return fail("migrating schema", err)
	}
	if err := s.ensureDefaultAdmin(); err != nil {
		return fail("ensuring default admin", err)
	}
	return s, nil
}
|
|
|
|
// migrate creates every table and index used by this package. Every statement
// is CREATE ... IF NOT EXISTS, so re-running it against an initialised
// database is a no-op. It also never alters an existing table: a column added
// here later will not appear in a database created by an older version.
//
// Statements run in order, with users/groups before the tables that reference
// them. The first failure stops the run and is returned, tagged with the
// statement's position.
func (s *Service) migrate() error {
	stmts := []string{
		`CREATE TABLE IF NOT EXISTS users (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			username TEXT UNIQUE NOT NULL,
			email TEXT UNIQUE NOT NULL,
			password_hash TEXT NOT NULL,
			is_active INTEGER NOT NULL DEFAULT 1,
			email_confirmed INTEGER NOT NULL DEFAULT 0,
			mfa_secret TEXT NULL,
			created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
			updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
		)`,
		`CREATE TABLE IF NOT EXISTS groups (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			name TEXT UNIQUE NOT NULL
		)`,
		`CREATE TABLE IF NOT EXISTS user_groups (
			user_id INTEGER NOT NULL,
			group_id INTEGER NOT NULL,
			PRIMARY KEY(user_id, group_id),
			FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE,
			FOREIGN KEY(group_id) REFERENCES groups(id) ON DELETE CASCADE
		)`,
		`CREATE TABLE IF NOT EXISTS permissions (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			group_id INTEGER NOT NULL,
			path_prefix TEXT NOT NULL,
			can_read INTEGER NOT NULL DEFAULT 1,
			can_write INTEGER NOT NULL DEFAULT 0,
			can_delete INTEGER NOT NULL DEFAULT 0,
			FOREIGN KEY(group_id) REFERENCES groups(id) ON DELETE CASCADE
		)`,
		`CREATE TABLE IF NOT EXISTS email_verification_tokens (
			user_id INTEGER NOT NULL,
			token TEXT NOT NULL UNIQUE,
			expires_at DATETIME NOT NULL,
			PRIMARY KEY(user_id),
			FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
		)`,
		`CREATE TABLE IF NOT EXISTS password_reset_tokens (
			user_id INTEGER NOT NULL,
			token TEXT NOT NULL UNIQUE,
			expires_at DATETIME NOT NULL,
			PRIMARY KEY(user_id),
			FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
		)`,
		`CREATE TABLE IF NOT EXISTS mfa_enrollments (
			user_id INTEGER PRIMARY KEY,
			secret TEXT NOT NULL,
			created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
			FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
		)`,

		// Access logs for all requests.
		`CREATE TABLE IF NOT EXISTS access_logs (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			user_id INTEGER NULL,
			ip TEXT NOT NULL,
			method TEXT NOT NULL,
			path TEXT NOT NULL,
			status INTEGER NOT NULL,
			duration_ms INTEGER NOT NULL,
			user_agent TEXT NULL,
			created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
			FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE SET NULL
		)`,
		`CREATE INDEX IF NOT EXISTS idx_access_logs_created_at ON access_logs(created_at)`,
		`CREATE INDEX IF NOT EXISTS idx_access_logs_user_id ON access_logs(user_id)`,
		`CREATE INDEX IF NOT EXISTS idx_access_logs_ip ON access_logs(ip)`,

		// Error logs for server-side errors.
		`CREATE TABLE IF NOT EXISTS error_logs (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			user_id INTEGER NULL,
			ip TEXT NULL,
			path TEXT NULL,
			message TEXT NOT NULL,
			stack TEXT NULL,
			created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
			FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE SET NULL
		)`,
		`CREATE INDEX IF NOT EXISTS idx_error_logs_created_at ON error_logs(created_at)`,

		// Failed login attempts; password and MFA failures are distinguished by type.
		`CREATE TABLE IF NOT EXISTS failed_logins (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			ip TEXT NOT NULL,
			user_id INTEGER NULL,
			username TEXT NULL,
			type TEXT NOT NULL CHECK(type IN ('password','mfa')),
			created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
			FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE SET NULL
		)`,
		`CREATE INDEX IF NOT EXISTS idx_failed_logins_ip_created ON failed_logins(ip, created_at)`,
		`CREATE INDEX IF NOT EXISTS idx_failed_logins_type_created ON failed_logins(type, created_at)`,

		// IP bans and whitelist.
		`CREATE TABLE IF NOT EXISTS ip_bans (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			ip TEXT NOT NULL UNIQUE,
			reason TEXT NULL,
			until DATETIME NULL,
			permanent INTEGER NOT NULL DEFAULT 0,
			whitelisted INTEGER NOT NULL DEFAULT 0,
			created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
			updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
		)`,
		// NOTE(review): ip is already UNIQUE, which gives it an implicit
		// index; this explicit one is redundant but kept so existing
		// databases and this schema stay in sync.
		`CREATE INDEX IF NOT EXISTS idx_ip_bans_ip ON ip_bans(ip)`,
		`CREATE INDEX IF NOT EXISTS idx_ip_bans_whitelist ON ip_bans(whitelisted)`,
		`CREATE INDEX IF NOT EXISTS idx_ip_bans_permanent ON ip_bans(permanent)`,
	}

	for i, stmt := range stmts {
		if _, err := s.DB.Exec(stmt); err != nil {
			return fmt.Errorf("migration statement %d of %d: %w", i+1, len(stmts), err)
		}
	}
	return nil
}
|
|
|
|
// ensureDefaultAdmin idempotently seeds, in a single transaction:
//   - the "admin" and "public" groups,
//   - a user named "admin" (only if none exists; an existing one is left
//     untouched, including its password),
//   - that user's membership in the "admin" group.
//
// SECURITY NOTE(review): a freshly created admin gets the password "admin"
// and is marked active and email-confirmed. Operators must change it after
// first start; consider generating a random password or taking it from
// configuration instead.
func (s *Service) ensureDefaultAdmin() error {
	tx, err := s.DB.Begin()
	if err != nil {
		return fmt.Errorf("beginning admin bootstrap transaction: %w", err)
	}
	// Rollback after a successful Commit is a harmless no-op (it returns
	// sql.ErrTxDone), so an unconditional deferred rollback undoes every
	// early-return path without depending on which error variable was set.
	defer func() { _ = tx.Rollback() }()

	// Ensure the built-in groups exist.
	if _, err := tx.Exec(`INSERT OR IGNORE INTO groups (name) VALUES (?), (?)`, "admin", "public"); err != nil {
		return fmt.Errorf("seeding default groups: %w", err)
	}

	// Ensure the admin user exists.
	var adminID int64
	err = tx.QueryRow(`SELECT id FROM users WHERE username = ?`, "admin").Scan(&adminID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		pwHash, hashErr := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.DefaultCost)
		if hashErr != nil {
			return fmt.Errorf("hashing default admin password: %w", hashErr)
		}
		res, execErr := tx.Exec(`INSERT INTO users (username, email, password_hash, is_active, email_confirmed) VALUES (?,?,?,?,?)`,
			"admin", "admin@local", string(pwHash), 1, 1,
		)
		if execErr != nil {
			return fmt.Errorf("creating default admin user: %w", execErr)
		}
		if adminID, err = res.LastInsertId(); err != nil {
			return fmt.Errorf("reading default admin user id: %w", err)
		}
	case err != nil:
		return fmt.Errorf("looking up admin user: %w", err)
	}

	// Resolve the admin group id.
	var adminGroupID int64
	if err := tx.QueryRow(`SELECT id FROM groups WHERE name = ?`, "admin").Scan(&adminGroupID); err != nil {
		return fmt.Errorf("looking up admin group: %w", err)
	}

	// Ensure membership admin -> admin group (re-added on every start if it
	// was removed).
	if _, err := tx.Exec(`INSERT OR IGNORE INTO user_groups (user_id, group_id) VALUES (?, ?)`, adminID, adminGroupID); err != nil {
		return fmt.Errorf("adding admin user to admin group: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("committing admin bootstrap: %w", err)
	}
	return nil
}
|
|
|
|
func (s *Service) Authenticate(usernameOrEmail, password string) (*User, error) {
|
|
u := &User{}
|
|
row := s.DB.QueryRow(`SELECT id, username, email, password_hash, is_active, email_confirmed, mfa_secret, created_at, updated_at FROM users WHERE username = ? OR email = ?`, usernameOrEmail, usernameOrEmail)
|
|
var mfa sql.NullString
|
|
if err := row.Scan(&u.ID, &u.Username, &u.Email, &u.PasswordHash, &u.IsActive, &u.EmailConfirmed, &mfa, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, errors.New("invalid credentials")
|
|
}
|
|
return nil, err
|
|
}
|
|
if err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)); err != nil {
|
|
return nil, errors.New("invalid credentials")
|
|
}
|
|
u.MFASecret = mfa
|
|
if !u.IsActive {
|
|
return nil, errors.New("account not active")
|
|
}
|
|
if s.Config.RequireEmailConfirmation && !u.EmailConfirmed {
|
|
return nil, errors.New("email not confirmed")
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
// CSRF utilities

// Names used to store and look up the per-session CSRF token.
const (
	// csrfSessionKey is the session-map key holding the expected token.
	csrfSessionKey = "csrf_token"
	// csrfHeader is the request header checked first for a submitted token.
	csrfHeader = "X-CSRF-Token"
	// csrfFormField is the form field used when the header is absent.
	csrfFormField = "csrf_token"
)
|
|
|
|
// randomToken draws n bytes from crypto/rand and returns them encoded as
// unpadded URL-safe base64. On a read failure it returns "" and the error.
func randomToken(n int) (tok string, err error) {
	raw := make([]byte, n)
	if _, err = rand.Read(raw); err == nil {
		tok = base64.RawURLEncoding.EncodeToString(raw)
	}
	return tok, err
}
|
|
|
|
// IssueCSRF returns middleware that makes sure the request's session map
// carries a CSRF token and exposes it to later handlers as the "csrf_token"
// context value.
//
// An existing non-empty token in the session is reused; otherwise a fresh
// 32-byte random token is generated and stored. When no usable session map is
// attached (missing, not a map[string]interface{}, or a nil map), the request
// proceeds without a token instead of panicking on a nil-map write. If
// generating a token fails, the request is aborted with 500, because an empty
// token could never pass RequireCSRF.
func IssueCSRF() gin.HandlerFunc {
	return func(c *gin.Context) {
		sessAny, exists := c.Get("session")
		if !exists {
			c.Next()
			return
		}
		sess, ok := sessAny.(map[string]interface{})
		if !ok || sess == nil {
			c.Next()
			return
		}

		token, _ := sess[csrfSessionKey].(string)
		if token == "" {
			var err error
			token, err = randomToken(32)
			if err != nil {
				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "could not generate CSRF token"})
				return
			}
			sess[csrfSessionKey] = token
		}

		c.Set("csrf_token", token)
		c.Next()
	}
}
|
|
|
|
func RequireCSRF() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead || c.Request.Method == http.MethodOptions {
|
|
c.Next()
|
|
return
|
|
}
|
|
var token string
|
|
// header first
|
|
if h := c.GetHeader(csrfHeader); h != "" {
|
|
token = h
|
|
} else {
|
|
token = c.PostForm(csrfFormField)
|
|
}
|
|
sessAny, exists := c.Get("session")
|
|
if !exists {
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "missing session for CSRF"})
|
|
return
|
|
}
|
|
sess, _ := sessAny.(map[string]interface{})
|
|
if expected, ok := sess[csrfSessionKey].(string); !ok || expected == "" || token == "" || !hmacEqual(expected, token) {
|
|
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "invalid CSRF token"})
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func hmacEqual(a, b string) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
var v byte
|
|
for i := 0; i < len(a); i++ {
|
|
v |= a[i] ^ b[i]
|
|
}
|
|
return v == 0
|
|
}
|
|
|
|
// Session helpers

const (
	// sessionCookieName is the name of the session cookie. It is not
	// referenced in this section; presumably it is used by the handlers
	// package (verify).
	sessionCookieName = "gobsidian_session"
)
|
|
|
|
// AttachSession should be installed early to create a simple session map backed by cookie store in handlers package
|
|
func AttachSession() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
// Handlers already use gorilla sessions. Here we just ensure a map exists to carry CSRF and auth info for templates.
|
|
if _, exists := c.Get("session"); !exists {
|
|
c.Set("session", map[string]interface{}{})
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// RequireAuth returns middleware that lets the request through only when a
// "user_id" value is present in the gin context. Otherwise it redirects with
// 302 Found to /editor/login and aborts the handler chain. The value itself is
// not inspected.
func RequireAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		if _, authenticated := c.Get("user_id"); authenticated {
			c.Next()
			return
		}
		c.Redirect(http.StatusFound, "/editor/login")
		c.Abort()
	}
}
|