user authentication
This commit is contained in:
385
internal/auth/service.go
Normal file
385
internal/auth/service.go
Normal file
@@ -0,0 +1,385 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"gobsidian/internal/config"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
DB *sql.DB
|
||||
Config *config.Config
|
||||
}
|
||||
|
||||
// HasReadAccess determines if a user (or the public) can read a given path.
|
||||
// Semantics:
|
||||
// - If no permission rows exist matching the path (by prefix), access is ALLOWED by default.
|
||||
// - If permission rows exist for the path:
|
||||
// - Access is allowed if any matching row has can_read=1 and the user belongs to that group.
|
||||
// - Unauthenticated users are treated as belonging only to the implicit 'public' group.
|
||||
// - Any row for group name 'public' grants access to everyone.
|
||||
// Path matching uses prefix match on path_prefix.
|
||||
func (s *Service) HasReadAccess(userID *int64, path string) (bool, error) {
|
||||
// First check if any permission rows exist for this path prefix
|
||||
var total int
|
||||
err := s.DB.QueryRow(`
|
||||
SELECT COUNT(1)
|
||||
FROM permissions p
|
||||
WHERE ? LIKE p.path_prefix || '%'
|
||||
`, path).Scan(&total)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// Default allow when no permissions defined for this path
|
||||
if total == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// If user is nil (public), allow if any matching permission for group 'public' with can_read
|
||||
if userID == nil {
|
||||
var exists int
|
||||
err := s.DB.QueryRow(`
|
||||
SELECT 1
|
||||
FROM permissions p
|
||||
JOIN groups g ON g.id = p.group_id
|
||||
WHERE p.can_read = 1
|
||||
AND ? LIKE p.path_prefix || '%'
|
||||
AND g.name = 'public'
|
||||
LIMIT 1
|
||||
`, path).Scan(&exists)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// For authenticated users, allow if any matching row with can_read and group's either 'public' or user is member
|
||||
var exists int
|
||||
err = s.DB.QueryRow(`
|
||||
SELECT 1
|
||||
FROM permissions p
|
||||
JOIN groups g ON g.id = p.group_id
|
||||
LEFT JOIN user_groups ug ON ug.group_id = g.id AND ug.user_id = ?
|
||||
WHERE p.can_read = 1
|
||||
AND ? LIKE p.path_prefix || '%'
|
||||
AND (g.name = 'public' OR ug.user_id IS NOT NULL)
|
||||
LIMIT 1
|
||||
`, *userID, path).Scan(&exists)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// IsUserInGroup returns true if the given user is a member of the given group name.
|
||||
func (s *Service) IsUserInGroup(userID int64, groupName string) (bool, error) {
|
||||
var exists int
|
||||
err := s.DB.QueryRow(`
|
||||
SELECT 1
|
||||
FROM user_groups ug
|
||||
JOIN groups g ON g.id = ug.group_id
|
||||
WHERE ug.user_id = ? AND g.name = ?
|
||||
LIMIT 1
|
||||
`, userID, groupName).Scan(&exists)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64
|
||||
Username string
|
||||
Email string
|
||||
PasswordHash string
|
||||
IsActive bool
|
||||
EmailConfirmed bool
|
||||
MFASecret sql.NullString
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func Open(cfg *config.Config) (*Service, error) {
|
||||
if strings.ToLower(cfg.DBType) != "sqlite" {
|
||||
return nil, fmt.Errorf("unsupported DB type: %s", cfg.DBType)
|
||||
}
|
||||
dsn := cfg.DBPath
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := db.Exec("PRAGMA foreign_keys = ON;"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &Service{DB: db, Config: cfg}
|
||||
if err := s.migrate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.ensureDefaultAdmin(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Service) migrate() error {
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
is_active INTEGER NOT NULL DEFAULT 1,
|
||||
email_confirmed INTEGER NOT NULL DEFAULT 0,
|
||||
mfa_secret TEXT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS groups (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT UNIQUE NOT NULL
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS user_groups (
|
||||
user_id INTEGER NOT NULL,
|
||||
group_id INTEGER NOT NULL,
|
||||
PRIMARY KEY(user_id, group_id),
|
||||
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY(group_id) REFERENCES groups(id) ON DELETE CASCADE
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS permissions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
group_id INTEGER NOT NULL,
|
||||
path_prefix TEXT NOT NULL,
|
||||
can_read INTEGER NOT NULL DEFAULT 1,
|
||||
can_write INTEGER NOT NULL DEFAULT 0,
|
||||
can_delete INTEGER NOT NULL DEFAULT 0,
|
||||
FOREIGN KEY(group_id) REFERENCES groups(id) ON DELETE CASCADE
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS email_verification_tokens (
|
||||
user_id INTEGER NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
expires_at DATETIME NOT NULL,
|
||||
PRIMARY KEY(user_id),
|
||||
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS password_reset_tokens (
|
||||
user_id INTEGER NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
expires_at DATETIME NOT NULL,
|
||||
PRIMARY KEY(user_id),
|
||||
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
)` ,
|
||||
`CREATE TABLE IF NOT EXISTS mfa_enrollments (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
secret TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
)` ,
|
||||
}
|
||||
for _, stmt := range stmts {
|
||||
if _, err := s.DB.Exec(stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) ensureDefaultAdmin() error {
|
||||
tx, err := s.DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// Ensure groups exist
|
||||
if _, err = tx.Exec(`INSERT OR IGNORE INTO groups (name) VALUES (?), (?)`, "admin", "public"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure admin user exists
|
||||
var adminID int64
|
||||
err = tx.QueryRow(`SELECT id FROM users WHERE username = ?`, "admin").Scan(&adminID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
pwHash, e := bcrypt.GenerateFromPassword([]byte("admin"), bcrypt.DefaultCost)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
res, e := tx.Exec(`INSERT INTO users (username, email, password_hash, is_active, email_confirmed) VALUES (?,?,?,?,?)`,
|
||||
"admin", "admin@local", string(pwHash), 1, 1,
|
||||
)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
adminID, err = res.LastInsertId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure admin group id
|
||||
var adminGroupID int64
|
||||
if err = tx.QueryRow(`SELECT id FROM groups WHERE name = ?`, "admin").Scan(&adminGroupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure membership admin -> admin group
|
||||
if _, err = tx.Exec(`INSERT OR IGNORE INTO user_groups (user_id, group_id) VALUES (?, ?)`, adminID, adminGroupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit
|
||||
if err = tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Authenticate(usernameOrEmail, password string) (*User, error) {
|
||||
u := &User{}
|
||||
row := s.DB.QueryRow(`SELECT id, username, email, password_hash, is_active, email_confirmed, mfa_secret, created_at, updated_at FROM users WHERE username = ? OR email = ?`, usernameOrEmail, usernameOrEmail)
|
||||
var mfa sql.NullString
|
||||
if err := row.Scan(&u.ID, &u.Username, &u.Email, &u.PasswordHash, &u.IsActive, &u.EmailConfirmed, &mfa, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)); err != nil {
|
||||
return nil, errors.New("invalid credentials")
|
||||
}
|
||||
u.MFASecret = mfa
|
||||
if !u.IsActive {
|
||||
return nil, errors.New("account not active")
|
||||
}
|
||||
if s.Config.RequireEmailConfirmation && !u.EmailConfirmed {
|
||||
return nil, errors.New("email not confirmed")
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// CSRF utilities
|
||||
const csrfSessionKey = "csrf_token"
|
||||
const csrfHeader = "X-CSRF-Token"
|
||||
const csrfFormField = "csrf_token"
|
||||
|
||||
func randomToken(n int) (string, error) {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func IssueCSRF() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
sessAny, exists := c.Get("session")
|
||||
if !exists {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
sess, _ := sessAny.(map[string]interface{})
|
||||
var token string
|
||||
if v, ok := sess[csrfSessionKey].(string); ok && v != "" {
|
||||
token = v
|
||||
} else {
|
||||
var err error
|
||||
token, err = randomToken(32)
|
||||
if err == nil {
|
||||
sess[csrfSessionKey] = token
|
||||
}
|
||||
}
|
||||
c.Set("csrf_token", token)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RequireCSRF() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead || c.Request.Method == http.MethodOptions {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
var token string
|
||||
// header first
|
||||
if h := c.GetHeader(csrfHeader); h != "" {
|
||||
token = h
|
||||
} else {
|
||||
token = c.PostForm(csrfFormField)
|
||||
}
|
||||
sessAny, exists := c.Get("session")
|
||||
if !exists {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "missing session for CSRF"})
|
||||
return
|
||||
}
|
||||
sess, _ := sessAny.(map[string]interface{})
|
||||
if expected, ok := sess[csrfSessionKey].(string); !ok || expected == "" || token == "" || !hmacEqual(expected, token) {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "invalid CSRF token"})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func hmacEqual(a, b string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
var v byte
|
||||
for i := 0; i < len(a); i++ {
|
||||
v |= a[i] ^ b[i]
|
||||
}
|
||||
return v == 0
|
||||
}
|
||||
|
||||
// Session helpers
|
||||
const sessionCookieName = "gobsidian_session"
|
||||
|
||||
// AttachSession should be installed early to create a simple session map backed by cookie store in handlers package
|
||||
func AttachSession() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Handlers already use gorilla sessions. Here we just ensure a map exists to carry CSRF and auth info for templates.
|
||||
if _, exists := c.Get("session"); !exists {
|
||||
c.Set("session", map[string]interface{}{})
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuth checks presence of user_id in context
|
||||
func RequireAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if _, exists := c.Get("user_id"); !exists {
|
||||
c.Redirect(http.StatusFound, "/editor/login")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -35,6 +35,23 @@ type Config struct {
|
||||
ShowFilesInTree bool
|
||||
ShowImagesInFolder bool
|
||||
ShowFilesInFolder bool
|
||||
|
||||
// Database settings
|
||||
DBType string
|
||||
DBPath string
|
||||
|
||||
// Auth settings
|
||||
RequireAdminActivation bool
|
||||
RequireEmailConfirmation bool
|
||||
MFAEnabledByDefault bool
|
||||
|
||||
// Email (SMTP) settings
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
SMTPUsername string
|
||||
SMTPPassword string
|
||||
SMTPSender string
|
||||
SMTPUseTLS bool
|
||||
}
|
||||
|
||||
var defaultConfig = map[string]map[string]string{
|
||||
@@ -62,6 +79,23 @@ var defaultConfig = map[string]map[string]string{
|
||||
"SHOW_IMAGES_IN_FOLDER": "true",
|
||||
"SHOW_FILES_IN_FOLDER": "true",
|
||||
},
|
||||
"DATABASE": {
|
||||
"TYPE": "sqlite",
|
||||
"PATH": "data/gobsidian.db",
|
||||
},
|
||||
"AUTH": {
|
||||
"REQUIRE_ADMIN_ACTIVATION": "true",
|
||||
"REQUIRE_EMAIL_CONFIRMATION": "true",
|
||||
"MFA_ENABLED_BY_DEFAULT": "false",
|
||||
},
|
||||
"EMAIL": {
|
||||
"SMTP_HOST": "",
|
||||
"SMTP_PORT": "587",
|
||||
"SMTP_USERNAME": "",
|
||||
"SMTP_PASSWORD": "",
|
||||
"SMTP_SENDER": "",
|
||||
"SMTP_USE_TLS": "true",
|
||||
},
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
@@ -130,6 +164,36 @@ func Load() (*Config, error) {
|
||||
config.ImageStoragePath = filepath.Join(wd, config.ImageStoragePath)
|
||||
}
|
||||
|
||||
// Load DATABASE section
|
||||
dbSection := cfg.Section("DATABASE")
|
||||
config.DBType = strings.ToLower(strings.TrimSpace(dbSection.Key("TYPE").String()))
|
||||
config.DBPath = dbSection.Key("PATH").String()
|
||||
if config.DBType == "sqlite" {
|
||||
if !filepath.IsAbs(config.DBPath) {
|
||||
wd, _ := os.Getwd()
|
||||
config.DBPath = filepath.Join(wd, config.DBPath)
|
||||
}
|
||||
// ensure parent dir exists
|
||||
if err := os.MkdirAll(filepath.Dir(config.DBPath), 0o755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create db directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load AUTH section
|
||||
authSection := cfg.Section("AUTH")
|
||||
config.RequireAdminActivation, _ = authSection.Key("REQUIRE_ADMIN_ACTIVATION").Bool()
|
||||
config.RequireEmailConfirmation, _ = authSection.Key("REQUIRE_EMAIL_CONFIRMATION").Bool()
|
||||
config.MFAEnabledByDefault, _ = authSection.Key("MFA_ENABLED_BY_DEFAULT").Bool()
|
||||
|
||||
// Load EMAIL (SMTP) section
|
||||
emailSection := cfg.Section("EMAIL")
|
||||
config.SMTPHost = emailSection.Key("SMTP_HOST").String()
|
||||
config.SMTPPort, _ = emailSection.Key("SMTP_PORT").Int()
|
||||
config.SMTPUsername = emailSection.Key("SMTP_USERNAME").String()
|
||||
config.SMTPPassword = emailSection.Key("SMTP_PASSWORD").String()
|
||||
config.SMTPSender = emailSection.Key("SMTP_SENDER").String()
|
||||
config.SMTPUseTLS, _ = emailSection.Key("SMTP_USE_TLS").Bool()
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -265,6 +329,39 @@ func (c *Config) SaveSetting(section, key, value string) error {
|
||||
case "SHOW_FILES_IN_FOLDER":
|
||||
c.ShowFilesInFolder = value == "true"
|
||||
}
|
||||
case "DATABASE":
|
||||
switch key {
|
||||
case "TYPE":
|
||||
c.DBType = strings.ToLower(strings.TrimSpace(value))
|
||||
case "PATH":
|
||||
c.DBPath = value
|
||||
}
|
||||
case "AUTH":
|
||||
switch key {
|
||||
case "REQUIRE_ADMIN_ACTIVATION":
|
||||
c.RequireAdminActivation = value == "true"
|
||||
case "REQUIRE_EMAIL_CONFIRMATION":
|
||||
c.RequireEmailConfirmation = value == "true"
|
||||
case "MFA_ENABLED_BY_DEFAULT":
|
||||
c.MFAEnabledByDefault = value == "true"
|
||||
}
|
||||
case "EMAIL":
|
||||
switch key {
|
||||
case "SMTP_HOST":
|
||||
c.SMTPHost = value
|
||||
case "SMTP_PORT":
|
||||
if v, err := strconv.Atoi(value); err == nil {
|
||||
c.SMTPPort = v
|
||||
}
|
||||
case "SMTP_USERNAME":
|
||||
c.SMTPUsername = value
|
||||
case "SMTP_PASSWORD":
|
||||
c.SMTPPassword = value
|
||||
case "SMTP_SENDER":
|
||||
c.SMTPSender = value
|
||||
case "SMTP_USE_TLS":
|
||||
c.SMTPUseTLS = value == "true"
|
||||
}
|
||||
}
|
||||
|
||||
return cfg.SaveTo(configPath)
|
||||
|
||||
274
internal/handlers/auth.go
Normal file
274
internal/handlers/auth.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/subtle"
|
||||
"crypto/rand"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const sessionCookieName = "gobsidian_session"
|
||||
|
||||
// LoginPage renders the login form
|
||||
func (h *Handlers) LoginPage(c *gin.Context) {
|
||||
token, _ := c.Get("csrf_token")
|
||||
c.HTML(http.StatusOK, "login", gin.H{
|
||||
"app_name": h.config.AppName,
|
||||
"csrf_token": token,
|
||||
"ContentTemplate": "login_content",
|
||||
"ScriptsTemplate": "login_scripts",
|
||||
"Page": "login",
|
||||
})
|
||||
}
|
||||
|
||||
// isAllDigits returns true if s consists only of ASCII digits 0-9
|
||||
func isAllDigits(s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] < '0' || s[i] > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// MFALoginPage shows OTP prompt when MFA is enabled
|
||||
func (h *Handlers) MFALoginPage(c *gin.Context) {
|
||||
session, _ := h.store.Get(c.Request, sessionCookieName)
|
||||
if _, ok := session.Values["mfa_user_id"]; !ok {
|
||||
c.Redirect(http.StatusFound, "/editor/login")
|
||||
return
|
||||
}
|
||||
token, _ := c.Get("csrf_token")
|
||||
c.HTML(http.StatusOK, "mfa", gin.H{
|
||||
"app_name": h.config.AppName,
|
||||
"csrf_token": token,
|
||||
"ContentTemplate": "mfa_content",
|
||||
"ScriptsTemplate": "mfa_scripts",
|
||||
"Page": "mfa",
|
||||
})
|
||||
}
|
||||
|
||||
// MFALoginVerify verifies OTP and completes login
|
||||
func (h *Handlers) MFALoginVerify(c *gin.Context) {
|
||||
code := strings.TrimSpace(c.PostForm("code"))
|
||||
if len(code) != 6 || !isAllDigits(code) {
|
||||
token, _ := c.Get("csrf_token")
|
||||
c.HTML(http.StatusUnauthorized, "mfa", gin.H{
|
||||
"app_name": h.config.AppName,
|
||||
"csrf_token": token,
|
||||
"error": "Invalid code format",
|
||||
"ContentTemplate": "mfa_content",
|
||||
"ScriptsTemplate": "mfa_scripts",
|
||||
"Page": "mfa",
|
||||
})
|
||||
return
|
||||
}
|
||||
session, _ := h.store.Get(c.Request, sessionCookieName)
|
||||
uidAny, ok := session.Values["mfa_user_id"]
|
||||
if !ok {
|
||||
c.Redirect(http.StatusFound, "/editor/login")
|
||||
return
|
||||
}
|
||||
uid, _ := uidAny.(int64)
|
||||
var secret string
|
||||
if err := h.authSvc.DB.QueryRow(`SELECT mfa_secret FROM users WHERE id = ?`, uid).Scan(&secret); err != nil || secret == "" {
|
||||
c.HTML(http.StatusUnauthorized, "mfa", gin.H{"error": "MFA not enabled", "Page": "mfa", "ContentTemplate": "mfa_content", "ScriptsTemplate": "mfa_scripts", "app_name": h.config.AppName})
|
||||
return
|
||||
}
|
||||
if !verifyTOTP(secret, code, time.Now()) {
|
||||
token, _ := c.Get("csrf_token")
|
||||
c.HTML(http.StatusUnauthorized, "mfa", gin.H{
|
||||
"app_name": h.config.AppName,
|
||||
"csrf_token": token,
|
||||
"error": "Invalid code",
|
||||
"ContentTemplate": "mfa_content",
|
||||
"ScriptsTemplate": "mfa_scripts",
|
||||
"Page": "mfa",
|
||||
})
|
||||
return
|
||||
}
|
||||
// success: set user_id and clear mfa_user_id
|
||||
delete(session.Values, "mfa_user_id")
|
||||
session.Values["user_id"] = uid
|
||||
_ = session.Save(c.Request, c.Writer)
|
||||
c.Redirect(http.StatusFound, "/")
|
||||
}
|
||||
|
||||
// ProfileMFASetupPage shows QR and input to verify during enrollment
|
||||
func (h *Handlers) ProfileMFASetupPage(c *gin.Context) {
|
||||
uidPtr := getUserIDPtr(c)
|
||||
if uidPtr == nil {
|
||||
c.Redirect(http.StatusFound, "/editor/login")
|
||||
return
|
||||
}
|
||||
// ensure enrollment exists, otherwise create one
|
||||
var secret string
|
||||
err := h.authSvc.DB.QueryRow(`SELECT secret FROM mfa_enrollments WHERE user_id = ?`, *uidPtr).Scan(&secret)
|
||||
if err != nil || secret == "" {
|
||||
// create new enrollment
|
||||
s, e := generateBase32Secret()
|
||||
if e != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{"error": "Failed to create enrollment", "Page": "error", "ContentTemplate": "error_content", "ScriptsTemplate": "error_scripts", "app_name": h.config.AppName})
|
||||
return
|
||||
}
|
||||
_, _ = h.authSvc.DB.Exec(`INSERT OR REPLACE INTO mfa_enrollments (user_id, secret) VALUES (?, ?)`, *uidPtr, s)
|
||||
secret = s
|
||||
}
|
||||
// Fetch username for label
|
||||
var username string
|
||||
_ = h.authSvc.DB.QueryRow(`SELECT username FROM users WHERE id = ?`, *uidPtr).Scan(&username)
|
||||
issuer := h.config.AppName
|
||||
label := url.PathEscape(fmt.Sprintf("%s:%s", issuer, username))
|
||||
otpauth := fmt.Sprintf("otpauth://totp/%s?secret=%s&issuer=%s&digits=6&period=30&algorithm=SHA1", label, secret, url.QueryEscape(issuer))
|
||||
|
||||
// Render simple page (uses base.html shell)
|
||||
c.HTML(http.StatusOK, "mfa_setup", gin.H{
|
||||
"app_name": h.config.AppName,
|
||||
"Secret": secret,
|
||||
"OTPAuthURI": otpauth,
|
||||
"ContentTemplate": "mfa_setup_content",
|
||||
"ScriptsTemplate": "mfa_setup_scripts",
|
||||
"Page": "mfa_setup",
|
||||
})
|
||||
}
|
||||
|
||||
// ProfileMFASetupVerify finalizes enrollment by verifying a TOTP code and setting mfa_secret
|
||||
func (h *Handlers) ProfileMFASetupVerify(c *gin.Context) {
|
||||
uidPtr := getUserIDPtr(c)
|
||||
if uidPtr == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "not authenticated"})
|
||||
return
|
||||
}
|
||||
code := strings.TrimSpace(c.PostForm("code"))
|
||||
if len(code) != 6 || !isAllDigits(code) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid code format"})
|
||||
return
|
||||
}
|
||||
var secret string
|
||||
if err := h.authSvc.DB.QueryRow(`SELECT secret FROM mfa_enrollments WHERE user_id = ?`, *uidPtr).Scan(&secret); err != nil || secret == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no enrollment found"})
|
||||
return
|
||||
}
|
||||
if !verifyTOTP(secret, code, time.Now()) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid code"})
|
||||
return
|
||||
}
|
||||
// move secret to users and delete enrollment
|
||||
if _, err := h.authSvc.DB.Exec(`UPDATE users SET mfa_secret = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, secret, *uidPtr); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
_, _ = h.authSvc.DB.Exec(`DELETE FROM mfa_enrollments WHERE user_id = ?`, *uidPtr)
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// TOTP helpers (SHA1, 30s window, 6 digits)
|
||||
func generateBase32Secret() (string, error) {
|
||||
// 20 random bytes -> base32 without padding
|
||||
b := make([]byte, 20)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
enc := base32.StdEncoding.WithPadding(base32.NoPadding)
|
||||
return enc.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func hotp(secret []byte, counter uint64) string {
|
||||
// HMAC-SHA1
|
||||
var buf [8]byte
|
||||
binary.BigEndian.PutUint64(buf[:], counter)
|
||||
mac := hmac.New(sha1.New, secret)
|
||||
mac.Write(buf[:])
|
||||
sum := mac.Sum(nil)
|
||||
// dynamic truncation
|
||||
offset := sum[len(sum)-1] & 0x0F
|
||||
code := (uint32(sum[offset])&0x7F)<<24 | (uint32(sum[offset+1])&0xFF)<<16 | (uint32(sum[offset+2])&0xFF)<<8 | (uint32(sum[offset+3]) & 0xFF)
|
||||
return fmt.Sprintf("%06d", code%1000000)
|
||||
}
|
||||
|
||||
func verifyTOTP(base32Secret, code string, t time.Time) bool {
|
||||
enc := base32.StdEncoding.WithPadding(base32.NoPadding)
|
||||
key, err := enc.DecodeString(strings.ToUpper(base32Secret))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
timestep := uint64(t.Unix() / 30)
|
||||
// allow +/- 1 window
|
||||
candidates := []string{
|
||||
hotp(key, timestep-1),
|
||||
hotp(key, timestep),
|
||||
hotp(key, timestep+1),
|
||||
}
|
||||
for _, c := range candidates {
|
||||
if subtle.ConstantTimeCompare([]byte(c), []byte(code)) == 1 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// LoginPost processes the login form
|
||||
func (h *Handlers) LoginPost(c *gin.Context) {
|
||||
username := c.PostForm("username")
|
||||
password := c.PostForm("password")
|
||||
|
||||
user, err := h.authSvc.Authenticate(username, password)
|
||||
if err != nil {
|
||||
token, _ := c.Get("csrf_token")
|
||||
c.HTML(http.StatusUnauthorized, "login", gin.H{
|
||||
"app_name": h.config.AppName,
|
||||
"csrf_token": token,
|
||||
"error": err.Error(),
|
||||
"ContentTemplate": "login_content",
|
||||
"ScriptsTemplate": "login_scripts",
|
||||
"Page": "login",
|
||||
})
|
||||
return
|
||||
}
|
||||
// If user has MFA enabled, require OTP before setting user_id
|
||||
if user.MFASecret.Valid && user.MFASecret.String != "" {
|
||||
session, _ := h.store.Get(c.Request, sessionCookieName)
|
||||
session.Values["mfa_user_id"] = user.ID
|
||||
_ = session.Save(c.Request, c.Writer)
|
||||
c.Redirect(http.StatusFound, "/editor/mfa")
|
||||
return
|
||||
}
|
||||
|
||||
// If admin created an enrollment for this user, force MFA setup after login
|
||||
var pending int
|
||||
if err := h.authSvc.DB.QueryRow(`SELECT 1 FROM mfa_enrollments WHERE user_id = ?`, user.ID).Scan(&pending); err == nil {
|
||||
// normal login, then redirect to setup
|
||||
session, _ := h.store.Get(c.Request, sessionCookieName)
|
||||
session.Values["user_id"] = user.ID
|
||||
_ = session.Save(c.Request, c.Writer)
|
||||
c.Redirect(http.StatusFound, "/editor/profile/mfa/setup")
|
||||
return
|
||||
}
|
||||
|
||||
// Create normal session
|
||||
session, _ := h.store.Get(c.Request, sessionCookieName)
|
||||
session.Values["user_id"] = user.ID
|
||||
_ = session.Save(c.Request, c.Writer)
|
||||
|
||||
c.Redirect(http.StatusFound, "/")
|
||||
}
|
||||
|
||||
// LogoutPost clears the session
|
||||
func (h *Handlers) LogoutPost(c *gin.Context) {
|
||||
session, _ := h.store.Get(c.Request, sessionCookieName)
|
||||
session.Options.MaxAge = -1
|
||||
_ = session.Save(c.Request, c.Writer)
|
||||
c.Redirect(http.StatusFound, "/editor/login")
|
||||
}
|
||||
@@ -1,18 +1,25 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"crypto/rand"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/h2non/filetype"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"gobsidian/internal/config"
|
||||
"gobsidian/internal/auth"
|
||||
"gobsidian/internal/markdown"
|
||||
"gobsidian/internal/models"
|
||||
"gobsidian/internal/utils"
|
||||
@@ -22,144 +29,569 @@ type Handlers struct {
|
||||
config *config.Config
|
||||
store *sessions.CookieStore
|
||||
renderer *markdown.Renderer
|
||||
authSvc *auth.Service
|
||||
}
|
||||
|
||||
// ProfilePage renders the user profile page for the signed-in user
|
||||
func (h *Handlers) ProfilePage(c *gin.Context) {
|
||||
// Must be authenticated; middleware ensures user_id is set
|
||||
uidPtr := getUserIDPtr(c)
|
||||
if uidPtr == nil {
|
||||
c.Redirect(http.StatusFound, "/editor/login")
|
||||
return
|
||||
}
|
||||
|
||||
// Load notes tree for sidebar
|
||||
notesTree, err := utils.BuildTreeStructure(h.config.NotesDir, h.config.NotesDirHideSidepane, h.config)
|
||||
if err != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{
|
||||
"error": "Failed to build tree structure",
|
||||
"app_name": h.config.AppName,
|
||||
"message": err.Error(),
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Fetch current user basic info
|
||||
var email string
|
||||
var mfa sql.NullString
|
||||
row := h.authSvc.DB.QueryRow(`SELECT email, mfa_secret FROM users WHERE id = ?`, *uidPtr)
|
||||
if err := row.Scan(&email, &mfa); err != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{
|
||||
"error": "Failed to load profile",
|
||||
"app_name": h.config.AppName,
|
||||
"message": err.Error(),
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.HTML(http.StatusOK, "profile", gin.H{
|
||||
"app_name": h.config.AppName,
|
||||
"notes_tree": notesTree,
|
||||
"active_path": []string{},
|
||||
"current_note": nil,
|
||||
"breadcrumbs": utils.GenerateBreadcrumbs(""),
|
||||
"Authenticated": true,
|
||||
"IsAdmin": isAdmin(c),
|
||||
"Email": email,
|
||||
"MFAEnabled": mfa.Valid && mfa.String != "",
|
||||
"ContentTemplate": "profile_content",
|
||||
"ScriptsTemplate": "profile_scripts",
|
||||
"Page": "profile",
|
||||
})
|
||||
}
|
||||
|
||||
// PostProfileChangePassword allows the user to change their password with current password verification
|
||||
func (h *Handlers) PostProfileChangePassword(c *gin.Context) {
|
||||
uidPtr := getUserIDPtr(c)
|
||||
if uidPtr == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
current := c.PostForm("current_password")
|
||||
newpw := c.PostForm("new_password")
|
||||
confirm := c.PostForm("confirm_password")
|
||||
if current == "" || newpw == "" || confirm == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "all password fields are required"})
|
||||
return
|
||||
}
|
||||
if newpw != confirm {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "new password and confirmation do not match"})
|
||||
return
|
||||
}
|
||||
if len(newpw) < 8 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "password must be at least 8 characters"})
|
||||
return
|
||||
}
|
||||
|
||||
var pwHash string
|
||||
row := h.authSvc.DB.QueryRow(`SELECT password_hash FROM users WHERE id = ?`, *uidPtr)
|
||||
if err := row.Scan(&pwHash); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load user"})
|
||||
return
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(pwHash), []byte(current)); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "current password is incorrect"})
|
||||
return
|
||||
}
|
||||
// Hash new password
|
||||
newHashBytes, err := bcrypt.GenerateFromPassword([]byte(newpw), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to hash password"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`UPDATE users SET password_hash = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, string(newHashBytes), *uidPtr); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// PostProfileChangeEmail allows the user to update their email
|
||||
func (h *Handlers) PostProfileChangeEmail(c *gin.Context) {
|
||||
uidPtr := getUserIDPtr(c)
|
||||
if uidPtr == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "not authenticated"})
|
||||
return
|
||||
}
|
||||
email := strings.TrimSpace(c.PostForm("email"))
|
||||
if email == "" || !strings.Contains(email, "@") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid email"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`UPDATE users SET email = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, email, *uidPtr); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// PostProfileEnableMFA generates and stores a new MFA secret for the user
|
||||
func (h *Handlers) PostProfileEnableMFA(c *gin.Context) {
|
||||
uidPtr := getUserIDPtr(c)
|
||||
if uidPtr == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "not authenticated"})
|
||||
return
|
||||
}
|
||||
// Create or replace enrollment for this user
|
||||
secret, err := generateBase32Secret()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate secret"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`INSERT OR REPLACE INTO mfa_enrollments (user_id, secret) VALUES (?, ?)`, *uidPtr, secret); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "setup": true, "redirect": "/editor/profile/mfa/setup"})
|
||||
}
|
||||
|
||||
// PostProfileDisableMFA clears the user's MFA secret
|
||||
func (h *Handlers) PostProfileDisableMFA(c *gin.Context) {
|
||||
uidPtr := getUserIDPtr(c)
|
||||
if uidPtr == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "not authenticated"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`UPDATE users SET mfa_secret = NULL, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, *uidPtr); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// AdminCreateUser creates a new user (admin only)
|
||||
func (h *Handlers) AdminCreateUser(c *gin.Context) {
|
||||
username := strings.TrimSpace(c.PostForm("username"))
|
||||
email := strings.TrimSpace(c.PostForm("email"))
|
||||
password := c.PostForm("password")
|
||||
if username == "" || email == "" || password == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "username, email and password are required"})
|
||||
return
|
||||
}
|
||||
// hash password
|
||||
pwHashBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to hash password"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`INSERT INTO users (username, email, password_hash, is_active, email_confirmed) VALUES (?,?,?,?,?)`,
|
||||
username, email, string(pwHashBytes), 1, 1,
|
||||
); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// AdminDeleteUser deletes a user by id (admin only)
|
||||
func (h *Handlers) AdminDeleteUser(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
// prevent deleting own account
|
||||
if v, ok := c.Get("user_id"); ok {
|
||||
if uid, ok2 := v.(int64); ok2 && uid == id {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "cannot delete your own account"})
|
||||
return
|
||||
}
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`DELETE FROM user_groups WHERE user_id = ?`, id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`DELETE FROM users WHERE id = ?`, id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// AdminSetUserActive enables or disables a user account
|
||||
func (h *Handlers) AdminSetUserActive(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
activeStr := c.PostForm("active")
|
||||
if activeStr == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing active value"})
|
||||
return
|
||||
}
|
||||
var activeInt int64
|
||||
if activeStr == "1" || strings.EqualFold(activeStr, "true") {
|
||||
activeInt = 1
|
||||
} else if activeStr == "0" || strings.EqualFold(activeStr, "false") {
|
||||
activeInt = 0
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "active must be 0/1 or true/false"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`UPDATE users SET is_active = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, activeInt, id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// AdminDisableUserMFA clears mfa_secret to disable MFA
|
||||
func (h *Handlers) AdminDisableUserMFA(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`UPDATE users SET mfa_secret = NULL, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// AdminResetUserMFA resets MFA by clearing secret (user must re-enroll)
|
||||
func (h *Handlers) AdminResetUserMFA(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`UPDATE users SET mfa_secret = NULL, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// AdminEnableUserMFA generates a new MFA secret for the user
|
||||
func (h *Handlers) AdminEnableUserMFA(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
// Create or replace an enrollment so user is prompted on next login
|
||||
secret, err := generateBase32Secret()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate secret"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`INSERT OR REPLACE INTO mfa_enrollments (user_id, secret) VALUES (?, ?)`, id, secret); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// generateSecret returns a URL-safe random string
|
||||
func generateSecret() (string, error) {
|
||||
// reuse auth.randomToken-style but local implementation
|
||||
b := make([]byte, 20)
|
||||
if _, err := io.ReadFull(randReader{}, b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// randReader wraps crypto/rand.Reader to satisfy io.Reader in a context where imports are at top
|
||||
type randReader struct{}
|
||||
|
||||
func (randReader) Read(p []byte) (int, error) { return rand.Read(p) }
|
||||
|
||||
// AdminCreateGroup creates a group
|
||||
func (h *Handlers) AdminCreateGroup(c *gin.Context) {
|
||||
name := strings.TrimSpace(c.PostForm("name"))
|
||||
if name == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "group name required"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`INSERT OR IGNORE INTO groups (name) VALUES (?)`, name); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// AdminDeleteGroup deletes a group by id
|
||||
func (h *Handlers) AdminDeleteGroup(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid group id"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`DELETE FROM user_groups WHERE group_id = ?`, id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`DELETE FROM groups WHERE id = ?`, id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// AdminAddUserToGroup links a user to a group
|
||||
func (h *Handlers) AdminAddUserToGroup(c *gin.Context) {
|
||||
userIDStr := c.PostForm("user_id")
|
||||
groupIDStr := c.PostForm("group_id")
|
||||
userID, err1 := strconv.ParseInt(userIDStr, 10, 64)
|
||||
groupID, err2 := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err1 != nil || err2 != nil || userID <= 0 || groupID <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user_id or group_id"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`INSERT OR IGNORE INTO user_groups (user_id, group_id) VALUES (?, ?)`, userID, groupID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// AdminRemoveUserFromGroup unlinks a user from a group
|
||||
func (h *Handlers) AdminRemoveUserFromGroup(c *gin.Context) {
|
||||
userIDStr := c.PostForm("user_id")
|
||||
groupIDStr := c.PostForm("group_id")
|
||||
userID, err1 := strconv.ParseInt(userIDStr, 10, 64)
|
||||
groupID, err2 := strconv.ParseInt(groupIDStr, 10, 64)
|
||||
if err1 != nil || err2 != nil || userID <= 0 || groupID <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user_id or group_id"})
|
||||
return
|
||||
}
|
||||
if _, err := h.authSvc.DB.Exec(`DELETE FROM user_groups WHERE user_id = ? AND group_id = ?`, userID, groupID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// isAuthenticated returns true if a user_id exists in the Gin context
|
||||
func isAuthenticated(c *gin.Context) bool {
|
||||
_, ok := c.Get("user_id")
|
||||
return ok
|
||||
}
|
||||
|
||||
// getUserIDPtr returns a pointer to user_id from context or nil if unauthenticated
|
||||
func getUserIDPtr(c *gin.Context) *int64 {
|
||||
if v, ok := c.Get("user_id"); ok {
|
||||
if id, ok2 := v.(int64); ok2 {
|
||||
return &id
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isAdmin returns true if the Gin context has is_admin flag set by middleware
|
||||
func isAdmin(c *gin.Context) bool {
|
||||
_, ok := c.Get("is_admin")
|
||||
return ok
|
||||
}
|
||||
|
||||
// EditTextPageHandler renders an editor for allowed text files (json, html, xml, yaml, etc.)
|
||||
func (h *Handlers) EditTextPageHandler(c *gin.Context) {
|
||||
filePath := strings.TrimPrefix(c.Param("path"), "/")
|
||||
filePath := strings.TrimPrefix(c.Param("path"), "/")
|
||||
|
||||
// Security check
|
||||
if strings.Contains(filePath, "..") {
|
||||
c.HTML(http.StatusBadRequest, "error", gin.H{
|
||||
"error": "Invalid path",
|
||||
"app_name": h.config.AppName,
|
||||
"message": "Path traversal is not allowed",
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Security check
|
||||
if strings.Contains(filePath, "..") {
|
||||
c.HTML(http.StatusBadRequest, "error", gin.H{
|
||||
"error": "Invalid path",
|
||||
"app_name": h.config.AppName,
|
||||
"message": "Path traversal is not allowed",
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(h.config.NotesDir, filePath)
|
||||
// Access control
|
||||
if allowed, err := h.authSvc.HasReadAccess(getUserIDPtr(c), filePath); err != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{
|
||||
"error": "Permission check failed",
|
||||
"app_name": h.config.AppName,
|
||||
"message": err.Error(),
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
} else if !allowed {
|
||||
c.HTML(http.StatusForbidden, "error", gin.H{
|
||||
"error": "Access denied",
|
||||
"app_name": h.config.AppName,
|
||||
"message": "You do not have permission to view this file",
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure file exists
|
||||
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||
c.HTML(http.StatusNotFound, "error", gin.H{
|
||||
"error": "File not found",
|
||||
"app_name": h.config.AppName,
|
||||
"message": "The requested file does not exist",
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
fullPath := filepath.Join(h.config.NotesDir, filePath)
|
||||
|
||||
// Only allow editing of configured text file types (not markdown here)
|
||||
ext := filepath.Ext(fullPath)
|
||||
ftype := models.GetFileType(ext, h.config.AllowedImageExtensions, h.config.AllowedFileExtensions)
|
||||
if ftype != models.FileTypeText {
|
||||
c.HTML(http.StatusForbidden, "error", gin.H{
|
||||
"error": "Editing not allowed",
|
||||
"app_name": h.config.AppName,
|
||||
"message": "This file type cannot be edited here",
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Ensure file exists
|
||||
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||
c.HTML(http.StatusNotFound, "error", gin.H{
|
||||
"error": "File not found",
|
||||
"app_name": h.config.AppName,
|
||||
"message": "The requested file does not exist",
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Load content
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{
|
||||
"error": "Failed to read file",
|
||||
"app_name": h.config.AppName,
|
||||
"message": err.Error(),
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Only allow editing of configured text file types (not markdown here)
|
||||
ext := filepath.Ext(fullPath)
|
||||
ftype := models.GetFileType(ext, h.config.AllowedImageExtensions, h.config.AllowedFileExtensions)
|
||||
if ftype != models.FileTypeText {
|
||||
c.HTML(http.StatusForbidden, "error", gin.H{
|
||||
"error": "Editing not allowed",
|
||||
"app_name": h.config.AppName,
|
||||
"message": "This file type cannot be edited here",
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Build notes tree
|
||||
notesTree, err := utils.BuildTreeStructure(h.config.NotesDir, h.config.NotesDirHideSidepane, h.config)
|
||||
if err != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{
|
||||
"error": "Failed to build notes tree",
|
||||
"app_name": h.config.AppName,
|
||||
"message": err.Error(),
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
// Load content
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{
|
||||
"error": "Failed to read file",
|
||||
"app_name": h.config.AppName,
|
||||
"message": err.Error(),
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
folderPath := filepath.Dir(filePath)
|
||||
if folderPath == "." {
|
||||
folderPath = ""
|
||||
}
|
||||
// Build notes tree
|
||||
notesTree, err := utils.BuildTreeStructure(h.config.NotesDir, h.config.NotesDirHideSidepane, h.config)
|
||||
if err != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{
|
||||
"error": "Failed to build notes tree",
|
||||
"app_name": h.config.AppName,
|
||||
"message": err.Error(),
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.HTML(http.StatusOK, "edit_text", gin.H{
|
||||
"app_name": h.config.AppName,
|
||||
"title": filepath.Base(filePath),
|
||||
"content": string(data),
|
||||
"file_path": filePath,
|
||||
"file_ext": strings.TrimPrefix(strings.ToLower(ext), "."),
|
||||
"folder_path": folderPath,
|
||||
"notes_tree": notesTree,
|
||||
"active_path": utils.GetActivePath(folderPath),
|
||||
"breadcrumbs": utils.GenerateBreadcrumbs(folderPath),
|
||||
"ContentTemplate": "edit_text_content",
|
||||
"ScriptsTemplate": "edit_text_scripts",
|
||||
"Page": "edit_text",
|
||||
})
|
||||
folderPath := filepath.Dir(filePath)
|
||||
if folderPath == "." {
|
||||
folderPath = ""
|
||||
}
|
||||
|
||||
c.HTML(http.StatusOK, "edit_text", gin.H{
|
||||
"app_name": h.config.AppName,
|
||||
"title": filepath.Base(filePath),
|
||||
"content": string(data),
|
||||
"file_path": filePath,
|
||||
"file_ext": strings.TrimPrefix(strings.ToLower(ext), "."),
|
||||
"folder_path": folderPath,
|
||||
"notes_tree": notesTree,
|
||||
"active_path": utils.GetActivePath(folderPath),
|
||||
"breadcrumbs": utils.GenerateBreadcrumbs(folderPath),
|
||||
"Authenticated": isAuthenticated(c),
|
||||
"IsAdmin": isAdmin(c),
|
||||
"ContentTemplate": "edit_text_content",
|
||||
"ScriptsTemplate": "edit_text_scripts",
|
||||
"Page": "edit_text",
|
||||
})
|
||||
}
|
||||
|
||||
// PostEditTextHandler saves changes to an allowed text file
|
||||
func (h *Handlers) PostEditTextHandler(c *gin.Context) {
|
||||
filePath := strings.TrimPrefix(c.Param("path"), "/")
|
||||
filePath := strings.TrimPrefix(c.Param("path"), "/")
|
||||
|
||||
if strings.Contains(filePath, "..") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid path"})
|
||||
return
|
||||
}
|
||||
if strings.Contains(filePath, "..") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid path"})
|
||||
return
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(h.config.NotesDir, filePath)
|
||||
// Access control
|
||||
if allowed, err := h.authSvc.HasReadAccess(getUserIDPtr(c), filePath); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Permission check failed"})
|
||||
return
|
||||
} else if !allowed {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Access denied"})
|
||||
return
|
||||
}
|
||||
|
||||
// Enforce allowed file type
|
||||
ext := filepath.Ext(fullPath)
|
||||
ftype := models.GetFileType(ext, h.config.AllowedImageExtensions, h.config.AllowedFileExtensions)
|
||||
if ftype != models.FileTypeText {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "This file type cannot be edited"})
|
||||
return
|
||||
}
|
||||
fullPath := filepath.Join(h.config.NotesDir, filePath)
|
||||
|
||||
content := c.PostForm("content")
|
||||
// Enforce allowed file type
|
||||
ext := filepath.Ext(fullPath)
|
||||
ftype := models.GetFileType(ext, h.config.AllowedImageExtensions, h.config.AllowedFileExtensions)
|
||||
if ftype != models.FileTypeText {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "This file type cannot be edited"})
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure parent directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create parent directory"})
|
||||
return
|
||||
}
|
||||
content := c.PostForm("content")
|
||||
|
||||
if err := os.WriteFile(fullPath, []byte(content), 0o644); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save file"})
|
||||
return
|
||||
}
|
||||
// Ensure parent directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create parent directory"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "redirect": "/view_text/" + filePath})
|
||||
if err := os.WriteFile(fullPath, []byte(content), 0o644); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to save file"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "redirect": "/view_text/" + filePath})
|
||||
}
|
||||
|
||||
func New(cfg *config.Config, store *sessions.CookieStore) *Handlers {
|
||||
func New(cfg *config.Config, store *sessions.CookieStore, authSvc *auth.Service) *Handlers {
|
||||
return &Handlers{
|
||||
config: cfg,
|
||||
store: store,
|
||||
renderer: markdown.NewRenderer(cfg),
|
||||
authSvc: authSvc,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,6 +630,29 @@ func (h *Handlers) IndexHandler(c *gin.Context) {
|
||||
|
||||
fmt.Printf("DEBUG: Tree structure built, app_name: %s\n", h.config.AppName)
|
||||
|
||||
// Access control
|
||||
if allowed, err := h.authSvc.HasReadAccess(getUserIDPtr(c), ""); err != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{
|
||||
"error": "Permission check failed",
|
||||
"app_name": h.config.AppName,
|
||||
"message": err.Error(),
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
} else if !allowed {
|
||||
c.HTML(http.StatusForbidden, "error", gin.H{
|
||||
"error": "Access denied",
|
||||
"app_name": h.config.AppName,
|
||||
"message": "You do not have permission to view this folder",
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.HTML(http.StatusOK, "folder", gin.H{
|
||||
"app_name": h.config.AppName,
|
||||
"folder_path": "",
|
||||
@@ -208,6 +663,8 @@ func (h *Handlers) IndexHandler(c *gin.Context) {
|
||||
"breadcrumbs": utils.GenerateBreadcrumbs(""),
|
||||
"allowed_image_extensions": h.config.AllowedImageExtensions,
|
||||
"allowed_file_extensions": h.config.AllowedFileExtensions,
|
||||
"Authenticated": isAuthenticated(c),
|
||||
"IsAdmin": isAdmin(c),
|
||||
"ContentTemplate": "folder_content",
|
||||
"ScriptsTemplate": "folder_scripts",
|
||||
"Page": "folder",
|
||||
@@ -243,6 +700,29 @@ func (h *Handlers) FolderHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Access control
|
||||
if allowed, err := h.authSvc.HasReadAccess(getUserIDPtr(c), folderPath); err != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{
|
||||
"error": "Permission check failed",
|
||||
"app_name": h.config.AppName,
|
||||
"message": err.Error(),
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
} else if !allowed {
|
||||
c.HTML(http.StatusForbidden, "error", gin.H{
|
||||
"error": "Access denied",
|
||||
"app_name": h.config.AppName,
|
||||
"message": "You do not have permission to view this folder",
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
folderContents, err := utils.GetFolderContents(folderPath, h.config)
|
||||
if err != nil {
|
||||
c.HTML(http.StatusNotFound, "error", gin.H{
|
||||
@@ -279,6 +759,8 @@ func (h *Handlers) FolderHandler(c *gin.Context) {
|
||||
"breadcrumbs": utils.GenerateBreadcrumbs(folderPath),
|
||||
"allowed_image_extensions": h.config.AllowedImageExtensions,
|
||||
"allowed_file_extensions": h.config.AllowedFileExtensions,
|
||||
"Authenticated": isAuthenticated(c),
|
||||
"IsAdmin": isAdmin(c),
|
||||
"ContentTemplate": "folder_content",
|
||||
"ScriptsTemplate": "folder_scripts",
|
||||
"Page": "folder",
|
||||
@@ -326,6 +808,29 @@ func (h *Handlers) NoteHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Access control
|
||||
if allowed, err := h.authSvc.HasReadAccess(getUserIDPtr(c), notePath); err != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{
|
||||
"error": "Permission check failed",
|
||||
"app_name": h.config.AppName,
|
||||
"message": err.Error(),
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
} else if !allowed {
|
||||
c.HTML(http.StatusForbidden, "error", gin.H{
|
||||
"error": "Access denied",
|
||||
"app_name": h.config.AppName,
|
||||
"message": "You do not have permission to view this note",
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(h.config.NotesDir, notePath)
|
||||
|
||||
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||
@@ -395,6 +900,8 @@ func (h *Handlers) NoteHandler(c *gin.Context) {
|
||||
"active_path": utils.GetActivePath(folderPath),
|
||||
"current_note": notePath,
|
||||
"breadcrumbs": utils.GenerateBreadcrumbs(folderPath),
|
||||
"Authenticated": isAuthenticated(c),
|
||||
"IsAdmin": isAdmin(c),
|
||||
"ContentTemplate": "note_content",
|
||||
"ScriptsTemplate": "note_scripts",
|
||||
"Page": "note",
|
||||
@@ -430,6 +937,15 @@ func (h *Handlers) ServeAttachedImageHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Access control
|
||||
if allowed, err := h.authSvc.HasReadAccess(getUserIDPtr(c), imagePath); err != nil {
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
} else if !allowed {
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if !models.IsImageFile(filepath.Base(imagePath), h.config.AllowedImageExtensions) {
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
return
|
||||
@@ -464,6 +980,9 @@ func (h *Handlers) ServeStoredImageHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Access control (stored images referenced by pathless filenames are assumed public unless permissions exist for the referencing path)
|
||||
// We cannot infer the note path here, so we allow by default per policy.
|
||||
|
||||
c.File(fullPath)
|
||||
}
|
||||
|
||||
@@ -476,6 +995,15 @@ func (h *Handlers) DownloadHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Access control
|
||||
if allowed, err := h.authSvc.HasReadAccess(getUserIDPtr(c), filePath); err != nil {
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
} else if !allowed {
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(h.config.NotesDir, filePath)
|
||||
|
||||
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||
@@ -504,6 +1032,29 @@ func (h *Handlers) ViewTextHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Access control
|
||||
if allowed, err := h.authSvc.HasReadAccess(getUserIDPtr(c), filePath); err != nil {
|
||||
c.HTML(http.StatusInternalServerError, "error", gin.H{
|
||||
"error": "Permission check failed",
|
||||
"app_name": h.config.AppName,
|
||||
"message": err.Error(),
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
} else if !allowed {
|
||||
c.HTML(http.StatusForbidden, "error", gin.H{
|
||||
"error": "Access denied",
|
||||
"app_name": h.config.AppName,
|
||||
"message": "You do not have permission to view this file",
|
||||
"ContentTemplate": "error_content",
|
||||
"ScriptsTemplate": "error_scripts",
|
||||
"Page": "error",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(h.config.NotesDir, filePath)
|
||||
|
||||
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||
@@ -578,6 +1129,8 @@ func (h *Handlers) ViewTextHandler(c *gin.Context) {
|
||||
"notes_tree": notesTree,
|
||||
"active_path": utils.GetActivePath(folderPath),
|
||||
"breadcrumbs": utils.GenerateBreadcrumbs(folderPath),
|
||||
"Authenticated": isAuthenticated(c),
|
||||
"IsAdmin": isAdmin(c),
|
||||
"ContentTemplate": "view_text_content",
|
||||
"ScriptsTemplate": "view_text_scripts",
|
||||
"Page": "view_text",
|
||||
@@ -686,6 +1239,88 @@ func (h *Handlers) TreeAPIHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, notesTree)
|
||||
}
|
||||
|
||||
// AdminPage renders a simple admin dashboard listing users, groups, and permissions
|
||||
func (h *Handlers) AdminPage(c *gin.Context) {
|
||||
// Query users
|
||||
users := make([]auth.User, 0, 32)
|
||||
if rows, err := h.authSvc.DB.Query(`SELECT id, username, email, password_hash, is_active, email_confirmed, mfa_secret, created_at, updated_at FROM users ORDER BY username`); err == nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var u auth.User
|
||||
var mfa sql.NullString
|
||||
if err := rows.Scan(&u.ID, &u.Username, &u.Email, &u.PasswordHash, &u.IsActive, &u.EmailConfirmed, &mfa, &u.CreatedAt, &u.UpdatedAt); err == nil {
|
||||
u.MFASecret = mfa
|
||||
users = append(users, u)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query groups
|
||||
type Group struct{ ID int64; Name string }
|
||||
groups := make([]Group, 0, 16)
|
||||
if gr, err := h.authSvc.DB.Query(`SELECT id, name FROM groups ORDER BY name`); err == nil {
|
||||
defer gr.Close()
|
||||
for gr.Next() {
|
||||
var g Group
|
||||
if err := gr.Scan(&g.ID, &g.Name); err == nil {
|
||||
groups = append(groups, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query permissions
|
||||
type Permission struct {
|
||||
Group string
|
||||
Path string
|
||||
CanRead bool
|
||||
CanWrite bool
|
||||
CanDelete bool
|
||||
}
|
||||
perms := make([]Permission, 0, 64)
|
||||
if pr, err := h.authSvc.DB.Query(`
|
||||
SELECT g.name, p.path_prefix, p.can_read, p.can_write, p.can_delete
|
||||
FROM permissions p
|
||||
JOIN groups g ON g.id = p.group_id
|
||||
ORDER BY g.name, p.path_prefix`); err == nil {
|
||||
defer pr.Close()
|
||||
for pr.Next() {
|
||||
var name, path string
|
||||
var r, w, d int
|
||||
if err := pr.Scan(&name, &path, &r, &w, &d); err == nil {
|
||||
perms = append(perms, Permission{Group: name, Path: path, CanRead: r == 1, CanWrite: w == 1, CanDelete: d == 1})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build tree for sidebar
|
||||
notesTree, _ := utils.BuildTreeStructure(h.config.NotesDir, h.config.NotesDirHideSidepane, h.config)
|
||||
|
||||
// current user id for UI restrictions (e.g., prevent self-delete)
|
||||
var currentUserID int64
|
||||
if v, ok := c.Get("user_id"); ok {
|
||||
if id, ok2 := v.(int64); ok2 {
|
||||
currentUserID = id
|
||||
}
|
||||
}
|
||||
|
||||
c.HTML(http.StatusOK, "admin", gin.H{
|
||||
"app_name": h.config.AppName,
|
||||
"notes_tree": notesTree,
|
||||
"active_path": []string{},
|
||||
"current_note": nil,
|
||||
"breadcrumbs": utils.GenerateBreadcrumbs(""),
|
||||
"Authenticated": true,
|
||||
"IsAdmin": true,
|
||||
"users": users,
|
||||
"groups": groups,
|
||||
"permissions": perms,
|
||||
"CurrentUserID": currentUserID,
|
||||
"ContentTemplate": "admin_content",
|
||||
"ScriptsTemplate": "admin_scripts",
|
||||
"Page": "admin",
|
||||
})
|
||||
}
|
||||
|
||||
// SearchHandler performs a simple full-text search across markdown and allowed text files
|
||||
// within the notes directory, honoring skipped directories.
|
||||
// GET /api/search?q=term
|
||||
|
||||
107
internal/server/middleware.go
Normal file
107
internal/server/middleware.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const csrfSessionKey = "csrf_token"
|
||||
|
||||
func (s *Server) randomToken(n int) (string, error) {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// SessionUser loads gorilla session and exposes user_id and csrf token to context
|
||||
func (s *Server) SessionUser() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
sess, _ := s.store.Get(c.Request, "gobsidian_session")
|
||||
if v, ok := sess.Values["user_id"].(int64); ok {
|
||||
c.Set("user_id", v)
|
||||
// derive admin flag
|
||||
if ok, err := s.auth.IsUserInGroup(v, "admin"); err == nil && ok {
|
||||
c.Set("is_admin", true)
|
||||
}
|
||||
}
|
||||
// ensure CSRF token exists in session
|
||||
tok, _ := sess.Values[csrfSessionKey].(string)
|
||||
if tok == "" {
|
||||
if t, err := s.randomToken(32); err == nil {
|
||||
sess.Values[csrfSessionKey] = t
|
||||
_ = sess.Save(c.Request, c.Writer)
|
||||
tok = t
|
||||
}
|
||||
}
|
||||
c.Set("csrf_token", tok)
|
||||
// expose CSRF token to client: header + non-HttpOnly cookie
|
||||
if tok != "" {
|
||||
c.Writer.Header().Set("X-CSRF-Token", tok)
|
||||
// cookie accessible to JS (HttpOnly=false). Secure/ SameSite Lax for CSRF
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: "csrf_token",
|
||||
Value: tok,
|
||||
Path: "/",
|
||||
HttpOnly: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// CSRFRequire validates the CSRF token for state-changing requests
|
||||
func (s *Server) CSRFRequire() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead || c.Request.Method == http.MethodOptions {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
sess, _ := s.store.Get(c.Request, "gobsidian_session")
|
||||
expected, _ := sess.Values[csrfSessionKey].(string)
|
||||
var token string
|
||||
if h := c.GetHeader("X-CSRF-Token"); h != "" {
|
||||
token = h
|
||||
} else {
|
||||
token = c.PostForm("csrf_token")
|
||||
}
|
||||
if expected == "" || token == "" || expected != token {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "invalid CSRF token"})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuth enforces authenticated access
|
||||
func (s *Server) RequireAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if _, exists := c.Get("user_id"); !exists {
|
||||
c.Redirect(http.StatusFound, "/editor/login")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAdmin enforces admin-only access
|
||||
func (s *Server) RequireAdmin() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if _, exists := c.Get("user_id"); !exists {
|
||||
c.Redirect(http.StatusFound, "/editor/login")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if _, ok := c.Get("is_admin"); !ok {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "admin required"})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/sessions"
|
||||
|
||||
"gobsidian/internal/auth"
|
||||
"gobsidian/internal/config"
|
||||
"gobsidian/internal/handlers"
|
||||
"gobsidian/internal/models"
|
||||
@@ -18,6 +19,7 @@ type Server struct {
|
||||
config *config.Config
|
||||
router *gin.Engine
|
||||
store *sessions.CookieStore
|
||||
auth *auth.Service
|
||||
}
|
||||
|
||||
func New(cfg *config.Config) *Server {
|
||||
@@ -28,12 +30,22 @@ func New(cfg *config.Config) *Server {
|
||||
router := gin.Default()
|
||||
store := sessions.NewCookieStore([]byte(cfg.SecretKey))
|
||||
|
||||
// Initialize auth service (panic on error during startup)
|
||||
authSvc, err := auth.Open(cfg)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to initialize auth: %w", err))
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
config: cfg,
|
||||
router: router,
|
||||
store: store,
|
||||
auth: authSvc,
|
||||
}
|
||||
|
||||
// Global middlewares: session user + template setup
|
||||
s.router.Use(s.SessionUser())
|
||||
|
||||
s.setupRoutes()
|
||||
s.setupStaticFiles()
|
||||
s.setupTemplates()
|
||||
@@ -62,7 +74,7 @@ func (s *Server) Start() error {
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
h := handlers.New(s.config, s.store)
|
||||
h := handlers.New(s.config, s.store, s.auth)
|
||||
|
||||
// Main routes
|
||||
s.router.GET("/", h.IndexHandler)
|
||||
@@ -74,27 +86,68 @@ func (s *Server) setupRoutes() {
|
||||
s.router.GET("/serve_stored_image/:filename", h.ServeStoredImageHandler)
|
||||
s.router.GET("/download/*path", h.DownloadHandler)
|
||||
s.router.GET("/view_text/*path", h.ViewTextHandler)
|
||||
s.router.GET("/edit_text/*path", h.EditTextPageHandler)
|
||||
s.router.POST("/edit_text/*path", h.PostEditTextHandler)
|
||||
|
||||
// Upload routes
|
||||
s.router.POST("/upload", h.UploadHandler)
|
||||
// Auth routes
|
||||
s.router.GET("/editor/login", h.LoginPage)
|
||||
s.router.POST("/editor/login", s.CSRFRequire(), h.LoginPost)
|
||||
s.router.POST("/editor/logout", s.RequireAuth(), s.CSRFRequire(), h.LogoutPost)
|
||||
// MFA challenge routes (no auth yet, but CSRF)
|
||||
s.router.GET("/editor/mfa", s.CSRFRequire(), h.MFALoginPage)
|
||||
s.router.POST("/editor/mfa", s.CSRFRequire(), h.MFALoginVerify)
|
||||
|
||||
// Settings routes
|
||||
s.router.GET("/settings", h.SettingsPageHandler)
|
||||
s.router.GET("/settings/image_storage", h.GetImageStorageSettingsHandler)
|
||||
s.router.POST("/settings/image_storage", h.PostImageStorageSettingsHandler)
|
||||
s.router.GET("/settings/notes_dir", h.GetNotesDirSettingsHandler)
|
||||
s.router.POST("/settings/notes_dir", h.PostNotesDirSettingsHandler)
|
||||
s.router.GET("/settings/file_extensions", h.GetFileExtensionsSettingsHandler)
|
||||
s.router.POST("/settings/file_extensions", h.PostFileExtensionsSettingsHandler)
|
||||
// New /editor group protected by auth + CSRF
|
||||
editor := s.router.Group("/editor", s.RequireAuth(), s.CSRFRequire())
|
||||
{
|
||||
editor.GET("/create", h.CreateNotePageHandler)
|
||||
editor.POST("/create", h.CreateNoteHandler)
|
||||
editor.GET("/edit/*path", h.EditNotePageHandler)
|
||||
editor.POST("/edit/*path", h.EditNoteHandler)
|
||||
editor.DELETE("/delete/*path", h.DeleteHandler)
|
||||
|
||||
// Editor routes
|
||||
s.router.GET("/create", h.CreateNotePageHandler)
|
||||
s.router.POST("/create", h.CreateNoteHandler)
|
||||
s.router.GET("/edit/*path", h.EditNotePageHandler)
|
||||
s.router.POST("/edit/*path", h.EditNoteHandler)
|
||||
s.router.DELETE("/delete/*path", h.DeleteHandler)
|
||||
// Text editor routes under /editor
|
||||
editor.GET("/edit_text/*path", h.EditTextPageHandler)
|
||||
editor.POST("/edit_text/*path", h.PostEditTextHandler)
|
||||
|
||||
// Upload under /editor (secured)
|
||||
editor.POST("/upload", h.UploadHandler)
|
||||
|
||||
// Settings under /editor
|
||||
editor.GET("/settings", h.SettingsPageHandler)
|
||||
editor.GET("/settings/image_storage", h.GetImageStorageSettingsHandler)
|
||||
editor.POST("/settings/image_storage", h.PostImageStorageSettingsHandler)
|
||||
editor.GET("/settings/notes_dir", h.GetNotesDirSettingsHandler)
|
||||
editor.POST("/settings/notes_dir", h.PostNotesDirSettingsHandler)
|
||||
editor.GET("/settings/file_extensions", h.GetFileExtensionsSettingsHandler)
|
||||
editor.POST("/settings/file_extensions", h.PostFileExtensionsSettingsHandler)
|
||||
|
||||
// Profile
|
||||
editor.GET("/profile", h.ProfilePage)
|
||||
editor.POST("/profile/password", h.PostProfileChangePassword)
|
||||
editor.POST("/profile/email", h.PostProfileChangeEmail)
|
||||
editor.POST("/profile/mfa/enable", h.PostProfileEnableMFA)
|
||||
editor.POST("/profile/mfa/disable", h.PostProfileDisableMFA)
|
||||
// MFA setup during enrollment
|
||||
editor.GET("/profile/mfa/setup", h.ProfileMFASetupPage)
|
||||
editor.POST("/profile/mfa/verify", h.ProfileMFASetupVerify)
|
||||
|
||||
// Admin dashboard
|
||||
editor.GET("/admin", s.RequireAdmin(), h.AdminPage)
|
||||
|
||||
// Admin CRUD API under /editor/admin
|
||||
admin := editor.Group("/admin", s.RequireAdmin())
|
||||
{
|
||||
admin.POST("/users", h.AdminCreateUser)
|
||||
admin.DELETE("/users/:id", h.AdminDeleteUser)
|
||||
admin.POST("/users/:id/active", h.AdminSetUserActive)
|
||||
admin.POST("/users/:id/mfa/enable", h.AdminEnableUserMFA)
|
||||
admin.POST("/users/:id/mfa/disable", h.AdminDisableUserMFA)
|
||||
admin.POST("/users/:id/mfa/reset", h.AdminResetUserMFA)
|
||||
admin.POST("/groups", h.AdminCreateGroup)
|
||||
admin.DELETE("/groups/:id", h.AdminDeleteGroup)
|
||||
admin.POST("/memberships/add", h.AdminAddUserToGroup)
|
||||
admin.POST("/memberships/remove", h.AdminRemoveUserFromGroup)
|
||||
}
|
||||
}
|
||||
|
||||
// API routes
|
||||
s.router.GET("/api/tree", h.TreeAPIHandler)
|
||||
|
||||
Reference in New Issue
Block a user