user authentication
This commit is contained in:
385
internal/auth/service.go
Normal file
385
internal/auth/service.go
Normal file
@@ -0,0 +1,385 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"gobsidian/internal/config"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
DB *sql.DB
|
||||
Config *config.Config
|
||||
}
|
||||
|
||||
// HasReadAccess determines if a user (or the public) can read a given path.
|
||||
// Semantics:
|
||||
// - If no permission rows exist matching the path (by prefix), access is ALLOWED by default.
|
||||
// - If permission rows exist for the path:
|
||||
// - Access is allowed if any matching row has can_read=1 and the user belongs to that group.
|
||||
// - Unauthenticated users are treated as belonging only to the implicit 'public' group.
|
||||
// - Any row for group name 'public' grants access to everyone.
|
||||
// Path matching uses prefix match on path_prefix.
|
||||
func (s *Service) HasReadAccess(userID *int64, path string) (bool, error) {
|
||||
// First check if any permission rows exist for this path prefix
|
||||
var total int
|
||||
err := s.DB.QueryRow(`
|
||||
SELECT COUNT(1)
|
||||
FROM permissions p
|
||||
WHERE ? LIKE p.path_prefix || '%'
|
||||
`, path).Scan(&total)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// Default allow when no permissions defined for this path
|
||||
if total == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// If user is nil (public), allow if any matching permission for group 'public' with can_read
|
||||
if userID == nil {
|
||||
var exists int
|
||||
err := s.DB.QueryRow(`
|
||||
SELECT 1
|
||||
FROM permissions p
|
||||
JOIN groups g ON g.id = p.group_id
|
||||
WHERE p.can_read = 1
|
||||
AND ? LIKE p.path_prefix || '%'
|
||||
AND g.name = 'public'
|
||||
LIMIT 1
|
||||
`, path).Scan(&exists)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// For authenticated users, allow if any matching row with can_read and group's either 'public' or user is member
|
||||
var exists int
|
||||
err = s.DB.QueryRow(`
|
||||
SELECT 1
|
||||
FROM permissions p
|
||||
JOIN groups g ON g.id = p.group_id
|
||||
LEFT JOIN user_groups ug ON ug.group_id = g.id AND ug.user_id = ?
|
||||
WHERE p.can_read = 1
|
||||
AND ? LIKE p.path_prefix || '%'
|
||||
AND (g.name = 'public' OR ug.user_id IS NOT NULL)
|
||||
LIMIT 1
|
||||
`, *userID, path).Scan(&exists)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// IsUserInGroup returns true if the given user is a member of the given group name.
|
||||
func (s *Service) IsUserInGroup(userID int64, groupName string) (bool, error) {
|
||||
var exists int
|
||||
err := s.DB.QueryRow(`
|
||||
SELECT 1
|
||||
FROM user_groups ug
|
||||
JOIN groups g ON g.id = ug.group_id
|
||||
WHERE ug.user_id = ? AND g.name = ?
|
||||
LIMIT 1
|
||||
`, userID, groupName).Scan(&exists)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64
|
||||
Username string
|
||||
Email string
|
||||
PasswordHash string
|
||||
IsActive bool
|
||||
EmailConfirmed bool
|
||||
MFASecret sql.NullString
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func Open(cfg *config.Config) (*Service, error) {
|
||||
if strings.ToLower(cfg.DBType) != "sqlite" {
|
||||
return nil, fmt.Errorf("unsupported DB type: %s", cfg.DBType)
|
||||
}
|
||||
dsn := cfg.DBPath
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := db.Exec("PRAGMA foreign_keys = ON;"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &Service{DB: db, Config: cfg}
|
||||
if err := s.migrate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.ensureDefaultAdmin(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Service) migrate() error {
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
is_active INTEGER NOT NULL DEFAULT 1,
|
||||
email_confirmed INTEGER NOT NULL DEFAULT 0,
|
||||
mfa_secret TEXT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS groups (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT UNIQUE NOT NULL
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS user_groups (
|
||||
user_id INTEGER NOT NULL,
|
||||
group_id INTEGER NOT NULL,
|
||||
PRIMARY KEY(user_id, group_id),
|
||||
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY(group_id) REFERENCES groups(id) ON DELETE CASCADE
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS permissions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
group_id INTEGER NOT NULL,
|
||||
path_prefix TEXT NOT NULL,
|
||||
can_read INTEGER NOT NULL DEFAULT 1,
|
||||
can_write INTEGER NOT NULL DEFAULT 0,
|
||||
can_delete INTEGER NOT NULL DEFAULT 0,
|
||||
FOREIGN KEY(group_id) REFERENCES groups(id) ON DELETE CASCADE
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS email_verification_tokens (
|
||||
user_id INTEGER NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
expires_at DATETIME NOT NULL,
|
||||
PRIMARY KEY(user_id),
|
||||
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS password_reset_tokens (
|
||||
user_id INTEGER NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
expires_at DATETIME NOT NULL,
|
||||
PRIMARY KEY(user_id),
|
||||
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS mfa_enrollments (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
secret TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
)` ,
|
||||
}
|
||||
for _, stmt := range stmts {
|
||||
if _, err := s.DB.Exec(stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) ensureDefaultAdmin() error {
|
||||
tx, err := s.DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// Ensure groups exist
|
||||
if _, err = tx.Exec(`INSERT OR IGNORE INTO groups (name) VALUES (?), (?)`, "admin", "public"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure admin user exists
|
||||
var adminID int64
|
||||
err = tx.QueryRow(`SELECT id FROM users WHERE username = ?`, "admin").Scan(&adminID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
pwHash, e := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.DefaultCost)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
res, e := tx.Exec(`INSERT INTO users (username, email, password_hash, is_active, email_confirmed) VALUES (?,?,?,?,?)`,
|
||||
"admin", "admin@local", string(pwHash), 1, 1,
|
||||
)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
adminID, err = res.LastInsertId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure admin group id
|
||||
var adminGroupID int64
|
||||
if err = tx.QueryRow(`SELECT id FROM groups WHERE name = ?`, "admin").Scan(&adminGroupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure membership admin -> admin group
|
||||
if _, err = tx.Exec(`INSERT OR IGNORE INTO user_groups (user_id, group_id) VALUES (?, ?)`, adminID, adminGroupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit
|
||||
if err = tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Authenticate(usernameOrEmail, password string) (*User, error) {
|
||||
u := &User{}
|
||||
row := s.DB.QueryRow(`SELECT id, username, email, password_hash, is_active, email_confirmed, mfa_secret, created_at, updated_at FROM users WHERE username = ? OR email = ?`, usernameOrEmail, usernameOrEmail)
|
||||
var mfa sql.NullString
|
||||
if err := row.Scan(&u.ID, &u.Username, &u.Email, &u.PasswordHash, &u.IsActive, &u.EmailConfirmed, &mfa, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)); err != nil {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
u.MFASecret = mfa
|
||||
if !u.IsActive {
|
||||
return nil, errors.New("account not active")
|
||||
}
|
||||
if s.Config.RequireEmailConfirmation && !u.EmailConfirmed {
|
||||
return nil, errors.New("email not confirmed")
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// CSRF utilities
|
||||
const csrfSessionKey = "csrf_token"
|
||||
const csrfHeader = "X-CSRF-Token"
|
||||
const csrfFormField = "csrf_token"
|
||||
|
||||
func randomToken(n int) (string, error) {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func IssueCSRF() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
sessAny, exists := c.Get("session")
|
||||
if !exists {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
sess, _ := sessAny.(map[string]interface{})
|
||||
var token string
|
||||
if v, ok := sess[csrfSessionKey].(string); ok && v != "" {
|
||||
token = v
|
||||
} else {
|
||||
var err error
|
||||
token, err = randomToken(32)
|
||||
if err == nil {
|
||||
sess[csrfSessionKey] = token
|
||||
}
|
||||
}
|
||||
c.Set("csrf_token", token)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RequireCSRF() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead || c.Request.Method == http.MethodOptions {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
var token string
|
||||
// header first
|
||||
if h := c.GetHeader(csrfHeader); h != "" {
|
||||
token = h
|
||||
} else {
|
||||
token = c.PostForm(csrfFormField)
|
||||
}
|
||||
sessAny, exists := c.Get("session")
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "missing session for CSRF"})
|
||||
return
|
||||
}
|
||||
sess, _ := sessAny.(map[string]interface{})
|
||||
if expected, ok := sess[csrfSessionKey].(string); !ok || expected == "" || token == "" || !hmacEqual(expected, token) {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "invalid CSRF token"})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func hmacEqual(a, b string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
var v byte
|
||||
for i := 0; i < len(a); i++ {
|
||||
v |= a[i] ^ b[i]
|
||||
}
|
||||
return v == 0
|
||||
}
|
||||
|
||||
// Session helpers
|
||||
const sessionCookieName = "gobsidian_session"
|
||||
|
||||
// AttachSession should be installed early to create a simple session map backed by cookie store in handlers package
|
||||
func AttachSession() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Handlers already use gorilla sessions. Here we just ensure a map exists to carry CSRF and auth info for templates.
|
||||
if _, exists := c.Get("session"); !exists {
|
||||
c.Set("session", map[string]interface{}{})
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuth checks presence of user_id in context
|
||||
func RequireAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if _, exists := c.Get("user_id"); !exists {
|
||||
c.Redirect(http.StatusFound, "/editor/login")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user