273 lines
6.3 KiB
Go
273 lines
6.3 KiB
Go
package main
|
|
|
|
import (
	"embed"
	"encoding/csv"
	"html/template"
	"io"
	"io/fs"
	"net/http"
	"net/url"
	"os"
	"reflect"
	"regexp"
	"sort"
	"strings"

	"github.com/justinas/nosurf"
	"github.com/pquerna/otp/totp"
)
|
|
|
|
// content is the read-only filesystem compiled into the binary: the HTML
// templates matched by templates/*.html plus the entries matched by static/*.
// renderTemplate parses its templates from the templates/ subtree of it.
//
//go:embed templates/*.html static/*
var content embed.FS
|
|
|
|
// TemplateData is the value page templates are executed with.
// renderTemplate fills CSRFToken, PageData, Authenticated, AppName and
// MFAEnabled, and copies Notification/NotificationType out of the page data
// when present. The MFA setup fields (MFASetupInProgress, MFASecret, MFAURL)
// are not set by renderTemplate.
type TemplateData struct {
	CSRFToken     string      // token from nosurf for the current request
	PageData      interface{} // handler-specific data for the page template
	Authenticated bool        // the session's "authenticated" flag
	AppName       string      // configured application name

	// Notification text and its type, as read from PageData.
	Notification, NotificationType string

	// MFA (TOTP) state; the setup fields are presumably filled in by the MFA
	// enrollment handlers — confirm against their callers.
	MFAEnabled, MFASetupInProgress bool
	MFASecret, MFAURL              string
}
|
|
|
|
func renderTemplate(w http.ResponseWriter, r *http.Request, tmpl string, pageData interface{}) {
|
|
// Check if user is authenticated
|
|
session, _ := store.Get(r, "session")
|
|
isAuth := false
|
|
if auth, ok := session.Values["authenticated"].(bool); ok {
|
|
isAuth = auth
|
|
}
|
|
|
|
configMu.RLock()
|
|
appName := config.AppName
|
|
mfaEnabled := config.MFASecret != ""
|
|
configMu.RUnlock()
|
|
|
|
td := TemplateData{
|
|
CSRFToken: nosurf.Token(r),
|
|
PageData: pageData,
|
|
Authenticated: isAuth,
|
|
AppName: appName,
|
|
MFAEnabled: mfaEnabled,
|
|
}
|
|
|
|
// Extract notification and type from pageData if it has those fields
|
|
if pageData != nil {
|
|
v := reflect.ValueOf(pageData)
|
|
if v.Kind() == reflect.Struct {
|
|
notifField := v.FieldByName("Notification")
|
|
typeField := v.FieldByName("NotificationType")
|
|
|
|
if notifField.IsValid() && notifField.Kind() == reflect.String {
|
|
td.Notification = notifField.String()
|
|
}
|
|
if typeField.IsValid() && typeField.Kind() == reflect.String {
|
|
td.NotificationType = typeField.String()
|
|
}
|
|
}
|
|
}
|
|
|
|
templatesFS, _ := fs.Sub(content, "templates")
|
|
// Parse base.html, notifications.html and the specific page template together
|
|
files := []string{"base.html", "notifications.html", tmpl}
|
|
t, err := template.ParseFS(templatesFS, files...)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
// Execute base.html which contains {{ block "content" . }}
|
|
// The specific page template's {{ define "content" }} will override the block
|
|
err = t.ExecuteTemplate(w, "base.html", td)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
}
|
|
}
|
|
|
|
// copyFile copies the contents of the file at src to dst, creating dst or
// truncating it if it already exists. It returns the first error from
// opening src, creating dst, copying the data, or closing dst.
//
// Closing dst is checked rather than deferred blindly: for a written file
// the Close error can be the only report that buffered data failed to reach
// disk.
func copyFile(src, dst string) (err error) {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close() // read-only handle; its Close error carries no data-loss signal

	out, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		// Report the Close error only if the copy itself succeeded.
		if cerr := out.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = io.Copy(out, in)
	return err
}
|
|
|
|
// readLines loads file and returns its "\n"-separated lines with leading and
// trailing whitespace trimmed, omitting lines that are empty after trimming.
// The result is nil (with a nil error) when the file has no non-blank lines.
func readLines(file string) ([]string, error) {
	raw, err := os.ReadFile(file)
	if err != nil {
		return nil, err
	}

	var kept []string
	parts := strings.Split(string(raw), "\n")
	for i := range parts {
		// TrimSpace also removes the '\r' left behind by CRLF line endings.
		if trimmed := strings.TrimSpace(parts[i]); len(trimmed) > 0 {
			kept = append(kept, trimmed)
		}
	}
	return kept, nil
}
|
|
|
|
// writeLines writes lines to file joined by "\n" (no trailing newline),
// creating it with permission bits 0644 if absent or truncating it otherwise.
//
// It returns the error from os.WriteFile so callers can detect a failed save.
// Call sites that use writeLines as a plain statement still compile and keep
// their fire-and-forget behaviour.
func writeLines(file string, lines []string) error {
	return os.WriteFile(file, []byte(strings.Join(lines, "\n")), 0644)
}
|
|
|
|
// mergeUnique merges two newline-separated lists into one newline-separated
// list that is de-duplicated and sorted in byte-wise lexicographic order.
// Every line is trimmed of surrounding whitespace first, and lines that are
// empty after trimming are dropped. Two empty inputs yield "".
func mergeUnique(data1, data2 string) string {
	seen := make(map[string]struct{})
	var merged []string

	for _, data := range [...]string{data1, data2} {
		for _, line := range strings.Split(data, "\n") {
			line = strings.TrimSpace(line)
			if line == "" {
				continue
			}
			if _, dup := seen[line]; dup {
				continue
			}
			seen[line] = struct{}{}
			merged = append(merged, line)
		}
	}

	// Entries are unique, so sort stability is irrelevant.
	sort.Strings(merged)
	return strings.Join(merged, "\n")
}
|
|
|
|
// unsafeInputChars matches the characters stripped by sanitizeInput. It is
// compiled once at package initialisation rather than on every call.
var unsafeInputChars = regexp.MustCompile(`[;<>&|]`)

// sanitizeInput returns input with every ';', '<', '>', '&' and '|' removed.
//
// NOTE(review): this is a small denylist, not a general-purpose escaper. It
// does not make input safe for a shell (quotes, '$', backticks and newlines
// pass through) or for HTML attributes. Prefer exec.Command argument lists
// and html/template auto-escaping at the point of use.
func sanitizeInput(input string) string {
	return unsafeInputChars.ReplaceAllString(input, "")
}
|
|
|
|
// isValidURL reports whether u parses as an absolute http or https URL with a
// non-empty hostname, e.g. "https://example.com/path".
//
// Strings that merely start with "http://" are not enough: "http://",
// "http://:80" and values that url.Parse rejects (spaces or control
// characters in the host, for example) are invalid. The scheme is compared
// after url.Parse lowercases it, so "HTTPS://Example.com" is accepted.
func isValidURL(u string) bool {
	parsed, err := url.Parse(u)
	if err != nil {
		return false
	}
	if parsed.Scheme != "http" && parsed.Scheme != "https" {
		return false
	}
	return parsed.Hostname() != ""
}
|
|
|
|
// domainPattern matches one or more dot-terminated DNS labels followed by an
// alphabetic top-level label of 2-63 letters. Each label is 1-63 ASCII
// letters, digits or hyphens and must start and end with a letter or digit.
var domainPattern = regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$`)

// isValidDomain reports whether d is a syntactically valid multi-label
// hostname such as "example.com" or "api.example.co.uk".
//
// It rejects single-label names ("localhost"), empty labels ("a..com"),
// labels beginning or ending with '-', labels longer than 63 characters,
// numeric TLDs (so dotted IPv4 addresses), trailing dots, and names longer
// than 253 characters.
func isValidDomain(d string) bool {
	return len(d) <= 253 && domainPattern.MatchString(d)
}
|
|
|
|
// URLListItem represents a single URL entry in the CSV
|
|
type URLListItem struct {
|
|
Name string
|
|
Enabled bool
|
|
Group string
|
|
URL string
|
|
Note string
|
|
}
|
|
|
|
// readURLListCSV reads the URL list from a CSV file
|
|
func readURLListCSV(file string) ([]URLListItem, error) {
|
|
data, err := os.ReadFile(file)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
lines := strings.Split(string(data), "\n")
|
|
var items []URLListItem
|
|
|
|
for i, line := range lines {
|
|
line = strings.TrimSpace(line)
|
|
if line == "" || (i == 0 && strings.HasPrefix(line, "Name,")) {
|
|
// Skip empty lines and header
|
|
continue
|
|
}
|
|
|
|
parts := strings.SplitN(line, ",", 5)
|
|
if len(parts) < 2 {
|
|
continue
|
|
}
|
|
|
|
item := URLListItem{
|
|
Name: strings.TrimSpace(parts[0]),
|
|
Enabled: strings.TrimSpace(parts[1]) == "true",
|
|
Group: "",
|
|
URL: "",
|
|
Note: "",
|
|
}
|
|
|
|
if len(parts) > 2 {
|
|
item.Group = strings.TrimSpace(parts[2])
|
|
}
|
|
if len(parts) > 3 {
|
|
item.URL = strings.TrimSpace(parts[3])
|
|
}
|
|
if len(parts) > 4 {
|
|
item.Note = strings.TrimSpace(parts[4])
|
|
}
|
|
|
|
items = append(items, item)
|
|
}
|
|
|
|
return items, nil
|
|
}
|
|
|
|
// writeURLListCSV serialises items to file as CSV, replacing any existing
// content. The first row is the header "Name,Enabled,Group,URL,Note". Each
// item then follows on its own "\n"-terminated row, with Enabled rendered as
// "true" or "false" and the text columns encoded by escapeCSVField. The file
// gets permission bits 0644 if it is created; the os.WriteFile error is
// returned.
func writeURLListCSV(file string, items []URLListItem) error {
	var buf strings.Builder
	buf.WriteString("Name,Enabled,Group,URL,Note\n")

	for _, it := range items {
		enabledCol := "false"
		if it.Enabled {
			enabledCol = "true"
		}
		row := []string{
			escapeCSVField(it.Name),
			enabledCol,
			escapeCSVField(it.Group),
			escapeCSVField(it.URL),
			escapeCSVField(it.Note),
		}
		buf.WriteString(strings.Join(row, ","))
		buf.WriteString("\n")
	}

	return os.WriteFile(file, []byte(buf.String()), 0644)
}
|
|
|
|
// escapeCSVField encodes field as a single CSV field in RFC 4180 style.
// A field containing a comma, double quote, carriage return or line feed is
// wrapped in double quotes, with each embedded quote doubled. Any other field,
// including "", is returned unchanged.
//
// '\r' must trigger quoting too: an unquoted CR (or CRLF) inside a field
// would otherwise be read back as a row break.
func escapeCSVField(field string) string {
	if !strings.ContainsAny(field, ",\"\r\n") {
		return field
	}
	return `"` + strings.ReplaceAll(field, `"`, `""`) + `"`
}
|
|
|
|
// generateSecret is a placeholder retained for backward compatibility; it
// performs no work and always returns the empty string.
//
// Deprecated: use generateMFASecret, which creates TOTP secrets with
// github.com/pquerna/otp/totp.
func generateSecret() string {
	return ""
}
|
|
|
|
// validateTOTP reports whether code is a currently valid TOTP code for the
// base32-encoded secret, using totp.Validate's default settings.
//
// A blank (empty or whitespace-only) secret is always rejected. As far as I
// can tell, the otp library would otherwise decode it to an empty HMAC key,
// whose codes anyone can compute. Failing closed avoids depending on that
// behaviour if a caller ever passes an unconfigured secret.
func validateTOTP(secret, code string) bool {
	if strings.TrimSpace(secret) == "" {
		return false
	}
	return totp.Validate(code, secret)
}
|
|
|
|
// generateMFASecret generates a new MFA secret and QR code URL
|
|
func generateMFASecret(username, appname string) (string, string) {
|
|
key, _ := totp.Generate(totp.GenerateOpts{
|
|
Issuer: appname,
|
|
AccountName: username,
|
|
})
|
|
return key.Secret(), key.URL()
|
|
}
|