Files
gobsidian/internal/auth/service.go
2025-08-25 21:19:15 +01:00

386 lines
11 KiB
Go

package auth
import (
"crypto/rand"
"database/sql"
"encoding/base64"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/crypto/bcrypt"
_ "modernc.org/sqlite"
"gobsidian/internal/config"
)
// Service bundles the database handle and the application configuration that
// the auth methods in this file operate on.
type Service struct {
// DB is the database/sql handle every query in this file runs against.
DB *sql.DB
// Config is the application configuration. Open reads DBType and DBPath
// from it; Authenticate reads RequireEmailConfirmation.
Config *config.Config
}
// HasReadAccess reports whether the given user may read path. A nil userID
// stands for an unauthenticated (public) visitor.
//
// Rules, in order:
//  1. When no permissions row has a path_prefix that prefixes path, reading is
//     allowed (default-allow).
//  2. Otherwise reading is allowed only if some matching row has can_read = 1
//     and belongs to the 'public' group, or (for authenticated users) to a
//     group the user is a member of.
//
// A non-nil error is returned only for database failures; "no matching row"
// is reported as (false, nil).
//
// NOTE(review): prefix matching is done with SQL LIKE, so '%' and '_' inside a
// stored path_prefix act as wildcards, and SQLite's LIKE is case-insensitive
// for ASCII by default — confirm this is the intended matching behavior.
func (s *Service) HasReadAccess(userID *int64, path string) (bool, error) {
	const countSQL = `
SELECT COUNT(1)
FROM permissions p
WHERE ? LIKE p.path_prefix || '%'
`
	const publicSQL = `
SELECT 1
FROM permissions p
JOIN groups g ON g.id = p.group_id
WHERE p.can_read = 1
AND ? LIKE p.path_prefix || '%'
AND g.name = 'public'
LIMIT 1
`
	const memberSQL = `
SELECT 1
FROM permissions p
JOIN groups g ON g.id = p.group_id
LEFT JOIN user_groups ug ON ug.group_id = g.id AND ug.user_id = ?
WHERE p.can_read = 1
AND ? LIKE p.path_prefix || '%'
AND (g.name = 'public' OR ug.user_id IS NOT NULL)
LIMIT 1
`

	// Step 1: is the path covered by any permission row at all?
	var matching int
	if err := s.DB.QueryRow(countSQL, path).Scan(&matching); err != nil {
		return false, err
	}
	if matching == 0 {
		return true, nil
	}

	// Step 2: look for a granting row, using the query that fits the caller.
	var grant *sql.Row
	if userID != nil {
		grant = s.DB.QueryRow(memberSQL, *userID, path)
	} else {
		grant = s.DB.QueryRow(publicSQL, path)
	}

	var hit int
	switch err := grant.Scan(&hit); {
	case err == nil:
		return true, nil
	case errors.Is(err, sql.ErrNoRows):
		// Rows exist for the path, but none of them grants read access.
		return false, nil
	default:
		return false, err
	}
}
// IsUserInGroup reports whether the user identified by userID has a
// user_groups row linking it to the group called groupName. A missing row is
// reported as (false, nil); any other database error is returned as-is.
func (s *Service) IsUserInGroup(userID int64, groupName string) (bool, error) {
	const membershipSQL = `
SELECT 1
FROM user_groups ug
JOIN groups g ON g.id = ug.group_id
WHERE ug.user_id = ? AND g.name = ?
LIMIT 1
`
	var found int
	scanErr := s.DB.QueryRow(membershipSQL, userID, groupName).Scan(&found)
	switch {
	case scanErr == nil:
		return true, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return false, nil
	default:
		return false, scanErr
	}
}
// User is the in-memory form of one row of the users table, as loaded by
// Authenticate.
type User struct {
// ID is the users.id primary key.
ID int64
// Username is the unique login name (users.username).
Username string
// Email is the unique e-mail address (users.email).
Email string
// PasswordHash is the stored hash that Authenticate verifies with bcrypt.
PasswordHash string
// IsActive mirrors users.is_active; Authenticate rejects inactive accounts.
IsActive bool
// EmailConfirmed mirrors users.email_confirmed; it is only enforced by
// Authenticate when Config.RequireEmailConfirmation is set.
EmailConfirmed bool
// MFASecret mirrors the nullable users.mfa_secret column.
MFASecret sql.NullString
// CreatedAt mirrors users.created_at.
CreatedAt time.Time
// UpdatedAt mirrors users.updated_at.
UpdatedAt time.Time
}
// Open creates the auth Service for the database described by cfg.
//
// It rejects any cfg.DBType other than "sqlite" (compared case-insensitively),
// opens cfg.DBPath with the "sqlite" driver, switches on foreign-key
// enforcement, applies the schema via migrate and seeds the default groups and
// admin account via ensureDefaultAdmin.
//
// If any step after sql.Open fails, the database handle is closed before the
// error is returned, so a failed Open does not leak it. Returned errors wrap
// the underlying cause with the name of the failing step and remain
// inspectable with errors.Is / errors.As.
func Open(cfg *config.Config) (*Service, error) {
	if strings.ToLower(cfg.DBType) != "sqlite" {
		return nil, fmt.Errorf("unsupported DB type: %s", cfg.DBType)
	}

	db, err := sql.Open("sqlite", cfg.DBPath)
	if err != nil {
		return nil, fmt.Errorf("opening sqlite database: %w", err)
	}

	// closeOnError releases the handle for a failed setup step and returns the
	// step's error annotated with what was being attempted.
	closeOnError := func(step string, cause error) (*Service, error) {
		_ = db.Close() // best effort: the setup error is the one worth reporting
		return nil, fmt.Errorf("%s: %w", step, cause)
	}

	// NOTE(review): PRAGMA foreign_keys is a per-connection setting in SQLite,
	// and database/sql may open additional pooled connections that never run
	// this statement — confirm whether the pragma should be carried in the DSN.
	if _, err := db.Exec("PRAGMA foreign_keys = ON;"); err != nil {
		return closeOnError("enabling foreign keys", err)
	}

	s := &Service{DB: db, Config: cfg}
	if err := s.migrate(); err != nil {
		return closeOnError("migrating schema", err)
	}
	if err := s.ensureDefaultAdmin(); err != nil {
		return closeOnError("ensuring default admin", err)
	}
	return s, nil
}
// migrate creates every table the auth package relies on. Each statement is
// CREATE TABLE IF NOT EXISTS, executed in declaration order directly on the
// database handle; the first failing statement's error is returned and the
// remaining statements are not run.
func (s *Service) migrate() error {
	schema := [...]string{
		// Accounts.
		`CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
email TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
is_active INTEGER NOT NULL DEFAULT 1,
email_confirmed INTEGER NOT NULL DEFAULT 0,
mfa_secret TEXT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
)`,
		// Named groups.
		`CREATE TABLE IF NOT EXISTS groups (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL
)`,
		// User-to-group membership.
		`CREATE TABLE IF NOT EXISTS user_groups (
user_id INTEGER NOT NULL,
group_id INTEGER NOT NULL,
PRIMARY KEY(user_id, group_id),
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE,
FOREIGN KEY(group_id) REFERENCES groups(id) ON DELETE CASCADE
)`,
		// Per-group rights on a path prefix.
		`CREATE TABLE IF NOT EXISTS permissions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
group_id INTEGER NOT NULL,
path_prefix TEXT NOT NULL,
can_read INTEGER NOT NULL DEFAULT 1,
can_write INTEGER NOT NULL DEFAULT 0,
can_delete INTEGER NOT NULL DEFAULT 0,
FOREIGN KEY(group_id) REFERENCES groups(id) ON DELETE CASCADE
)`,
		// One e-mail verification token per user.
		`CREATE TABLE IF NOT EXISTS email_verification_tokens (
user_id INTEGER NOT NULL,
token TEXT NOT NULL UNIQUE,
expires_at DATETIME NOT NULL,
PRIMARY KEY(user_id),
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
)`,
		// One password reset token per user.
		`CREATE TABLE IF NOT EXISTS password_reset_tokens (
user_id INTEGER NOT NULL,
token TEXT NOT NULL UNIQUE,
expires_at DATETIME NOT NULL,
PRIMARY KEY(user_id),
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
)`,
		// One MFA enrollment per user.
		`CREATE TABLE IF NOT EXISTS mfa_enrollments (
user_id INTEGER PRIMARY KEY,
secret TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
)`,
	}

	for i := 0; i < len(schema); i++ {
		_, execErr := s.DB.Exec(schema[i])
		if execErr != nil {
			return execErr
		}
	}
	return nil
}
// ensureDefaultAdmin seeds the minimum data the application needs, inside a
// single transaction:
//
//   - the groups "admin" and "public" (inserted only if missing);
//   - a user named "admin" (created only if no user has that username);
//   - membership of that user in the "admin" group (inserted only if missing).
//
// Any failure rolls the whole transaction back and is returned wrapped with
// the step that failed.
//
// SECURITY NOTE(review): a freshly created admin account gets the fixed,
// well-known password "admin" and e-mail "admin@local". This is kept for
// backward compatibility; deployments are expected to change it.
func (s *Service) ensureDefaultAdmin() error {
	tx, err := s.DB.Begin()
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	// Roll back on every exit path. After a successful Commit this is a no-op
	// that merely returns sql.ErrTxDone, so it never undoes committed work and
	// does not depend on which error variable a failing step happened to use.
	defer func() { _ = tx.Rollback() }()

	// Ensure groups exist.
	if _, err := tx.Exec(`INSERT OR IGNORE INTO groups (name) VALUES (?), (?)`, "admin", "public"); err != nil {
		return fmt.Errorf("seeding default groups: %w", err)
	}

	// Ensure the admin user exists, creating it when the lookup finds nothing.
	var adminID int64
	err = tx.QueryRow(`SELECT id FROM users WHERE username = ?`, "admin").Scan(&adminID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		pwHash, err := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.DefaultCost)
		if err != nil {
			return fmt.Errorf("hashing default admin password: %w", err)
		}
		res, err := tx.Exec(`INSERT INTO users (username, email, password_hash, is_active, email_confirmed) VALUES (?,?,?,?,?)`,
			"admin", "admin@local", string(pwHash), 1, 1,
		)
		if err != nil {
			return fmt.Errorf("creating default admin user: %w", err)
		}
		adminID, err = res.LastInsertId()
		if err != nil {
			return fmt.Errorf("reading default admin id: %w", err)
		}
	case err != nil:
		return fmt.Errorf("looking up admin user: %w", err)
	}

	// Resolve the admin group id.
	var adminGroupID int64
	if err := tx.QueryRow(`SELECT id FROM groups WHERE name = ?`, "admin").Scan(&adminGroupID); err != nil {
		return fmt.Errorf("looking up admin group: %w", err)
	}

	// Ensure membership admin -> admin group.
	if _, err := tx.Exec(`INSERT OR IGNORE INTO user_groups (user_id, group_id) VALUES (?, ?)`, adminID, adminGroupID); err != nil {
		return fmt.Errorf("adding admin to admin group: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("committing default admin setup: %w", err)
	}
	return nil
}
// Authenticate looks up a user whose username or e-mail equals
// usernameOrEmail and verifies password against the stored bcrypt hash.
//
// It returns the loaded User on success. Failures are reported as:
//   - "invalid credentials" when no such user exists or the password is wrong;
//   - "account not active" when the password is right but is_active is false;
//   - "email not confirmed" when Config.RequireEmailConfirmation is set and the
//     address has not been confirmed;
//   - the underlying error for any other database failure.
func (s *Service) Authenticate(usernameOrEmail, password string) (*User, error) {
	const lookupSQL = `SELECT id, username, email, password_hash, is_active, email_confirmed, mfa_secret, created_at, updated_at FROM users WHERE username = ? OR email = ?`

	var user User
	scanErr := s.DB.QueryRow(lookupSQL, usernameOrEmail, usernameOrEmail).Scan(
		&user.ID,
		&user.Username,
		&user.Email,
		&user.PasswordHash,
		&user.IsActive,
		&user.EmailConfirmed,
		&user.MFASecret,
		&user.CreatedAt,
		&user.UpdatedAt,
	)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		// Same message as a wrong password, so the two cases read alike.
		return nil, errors.New("invalid credentials")
	case scanErr != nil:
		return nil, scanErr
	}

	if bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) != nil {
		return nil, errors.New("invalid credentials")
	}

	// Account state is only disclosed once the password has been verified.
	switch {
	case !user.IsActive:
		return nil, errors.New("account not active")
	case s.Config.RequireEmailConfirmation && !user.EmailConfirmed:
		return nil, errors.New("email not confirmed")
	}
	return &user, nil
}
// CSRF utilities

// Names under which the CSRF token is stored and transmitted.
const (
	// csrfSessionKey is the session-map key holding the expected token.
	csrfSessionKey = "csrf_token"

	// csrfHeader is the request header RequireCSRF checks first.
	csrfHeader = "X-CSRF-Token"

	// csrfFormField is the POST form field RequireCSRF falls back to when the
	// header is empty.
	csrfFormField = "csrf_token"
)
// randomToken reads n bytes from crypto/rand and returns them encoded as
// unpadded, URL-safe base64. If the random source fails, it returns the empty
// string together with that error.
func randomToken(n int) (string, error) {
	raw := make([]byte, n)
	_, readErr := rand.Read(raw)
	if readErr != nil {
		return "", readErr
	}
	encoded := base64.RawURLEncoding.EncodeToString(raw)
	return encoded, nil
}
// IssueCSRF returns middleware that makes a CSRF token available to later
// handlers under the context key "csrf_token".
//
// The token lives in the map stored under the context key "session": an
// existing non-empty token is reused, otherwise a fresh 32-byte random token
// is generated and written back into that map.
//
// The middleware does nothing except call the next handler when there is no
// "session" value, or when that value is not a non-nil
// map[string]interface{} (writing a token into a nil map would panic). If
// token generation fails, the error is recorded on the gin context via
// c.Error and "csrf_token" is set to the empty string, which RequireCSRF
// rejects.
func IssueCSRF() gin.HandlerFunc {
	return func(c *gin.Context) {
		sessAny, exists := c.Get("session")
		if !exists {
			c.Next()
			return
		}
		sess, ok := sessAny.(map[string]interface{})
		if !ok || sess == nil {
			// Unusable session value: there is nowhere to persist a token.
			c.Next()
			return
		}

		token, _ := sess[csrfSessionKey].(string)
		if token == "" {
			fresh, err := randomToken(32)
			if err != nil {
				// Keep serving the request, but make the failure visible to
				// error-reporting middleware instead of dropping it.
				_ = c.Error(err)
			} else {
				token = fresh
				sess[csrfSessionKey] = token
			}
		}

		c.Set("csrf_token", token)
		c.Next()
	}
}
// RequireCSRF returns middleware that rejects state-changing requests lacking
// a valid CSRF token.
//
// GET, HEAD and OPTIONS requests pass through unchecked. For every other
// method the submitted token is taken from the X-CSRF-Token header or, when
// that header is empty, from the "csrf_token" POST form field, and compared
// with the token stored in the "session" map. The request is aborted with
// 403 and a JSON error body when the session is missing, or when either token
// is absent or they differ.
func RequireCSRF() gin.HandlerFunc {
	return func(c *gin.Context) {
		switch c.Request.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions:
			// Safe methods are exempt from the check.
			c.Next()
			return
		}

		// Header takes precedence; the form field is only consulted as fallback.
		submitted := c.GetHeader(csrfHeader)
		if submitted == "" {
			submitted = c.PostForm(csrfFormField)
		}

		sessAny, haveSession := c.Get("session")
		if !haveSession {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "missing session for CSRF"})
			return
		}

		// A value of the wrong type yields a nil map, whose lookup simply fails.
		sess, _ := sessAny.(map[string]interface{})
		expected, isString := sess[csrfSessionKey].(string)
		valid := isString && expected != "" && submitted != "" && hmacEqual(expected, submitted)
		if !valid {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "invalid CSRF token"})
			return
		}
		c.Next()
	}
}
// hmacEqual reports whether a and b hold the same bytes. Strings of different
// length are rejected immediately; for equal lengths every byte pair is
// examined and the differences are OR-ed together, so the loop does not stop
// at the first mismatch. Despite its name, no HMAC is computed.
func hmacEqual(a, b string) bool {
	if len(a) != len(b) {
		return false
	}
	var diff byte
	for i := len(a) - 1; i >= 0; i-- {
		diff |= a[i] ^ b[i]
	}
	return diff == 0
}
// Session helpers

// sessionCookieName is not referenced by any code visible in this file;
// presumably it is the name of the session cookie used elsewhere in the
// package — TODO confirm.
const sessionCookieName = "gobsidian_session"
// AttachSession returns middleware that guarantees a "session" value exists on
// the gin context. If an earlier handler already stored one it is left
// untouched; otherwise an empty map[string]interface{} is installed, which
// IssueCSRF and RequireCSRF use to carry the CSRF token. Install it early in
// the middleware chain.
func AttachSession() gin.HandlerFunc {
	return func(c *gin.Context) {
		// The original author notes that handlers already use gorilla sessions;
		// this map only carries CSRF and auth info for templates.
		_, present := c.Get("session")
		if !present {
			c.Set("session", make(map[string]interface{}))
		}
		c.Next()
	}
}
// RequireAuth returns middleware that lets a request continue only when the
// gin context carries a "user_id" value. Requests without one are redirected
// (302 Found) to /editor/login and the handler chain is aborted.
func RequireAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		_, authenticated := c.Get("user_id")
		if authenticated {
			c.Next()
			return
		}
		c.Redirect(http.StatusFound, "/editor/login")
		c.Abort()
	}
}