Files
unifi-adblocker/utils.go

273 lines
6.3 KiB
Go

package main
import (
	"embed"
	"encoding/csv"
	"html/template"
	"io"
	"io/fs"
	"log"
	"net/http"
	neturl "net/url"
	"os"
	"reflect"
	"regexp"
	"sort"
	"strings"

	"github.com/justinas/nosurf"
	"github.com/pquerna/otp/totp"
)
// content is the read-only filesystem compiled into the binary from the
// HTML templates matching templates/*.html and the files under static/.
// renderTemplate parses its pages from the "templates" subdirectory.
//
//go:embed templates/*.html static/*
var content embed.FS
// TemplateData is the root value handed to base.html by renderTemplate.
// Field order is part of the type's layout and is kept stable.
type TemplateData struct {
	// CSRFToken is the nosurf token to embed in forms.
	CSRFToken string
	// PageData is the handler-specific value for the page template.
	PageData interface{}
	// Authenticated mirrors the session's "authenticated" flag.
	Authenticated bool
	// AppName is the configured application name.
	AppName string
	// Notification and NotificationType are copied from PageData's
	// fields of the same names when present.
	Notification, NotificationType string
	// MFAEnabled reports whether an MFA secret is configured;
	// MFASetupInProgress is not set by renderTemplate (presumably set by
	// MFA setup handlers — confirm against callers).
	MFAEnabled, MFASetupInProgress bool
	// MFASecret and MFAURL are not set by renderTemplate; presumably they
	// carry a freshly generated secret and its key URL during MFA setup.
	MFASecret, MFAURL string
}
// renderTemplate renders the page template tmpl (a file name under the
// embedded templates/ directory) inside base.html and writes the result to w.
//
// The data passed to base.html is a TemplateData that combines:
//   - per-request state: the nosurf CSRF token and the "authenticated" flag
//     from the "session" session;
//   - a snapshot of the global config taken under configMu: AppName and
//     whether an MFA secret is set;
//   - the caller's pageData. If pageData is a struct, or a pointer to one,
//     its string fields Notification and NotificationType are copied to the
//     top level as well.
//
// The page is rendered into a buffer before anything is written. A template
// execution error therefore produces a clean 500 response rather than a
// half-written page followed by an error message. Error details are logged
// server-side and not sent to the client.
func renderTemplate(w http.ResponseWriter, r *http.Request, tmpl string, pageData interface{}) {
	// The error from store.Get is ignored. If "authenticated" is missing or
	// not a bool, the request is treated as unauthenticated. (Assumes store.Get
	// still returns a non-nil session on error — TODO confirm.)
	session, _ := store.Get(r, "session")
	isAuth := false
	if auth, ok := session.Values["authenticated"].(bool); ok {
		isAuth = auth
	}

	configMu.RLock()
	appName := config.AppName
	mfaEnabled := config.MFASecret != ""
	configMu.RUnlock()

	td := TemplateData{
		CSRFToken:     nosurf.Token(r),
		PageData:      pageData,
		Authenticated: isAuth,
		AppName:       appName,
		MFAEnabled:    mfaEnabled,
	}
	td.Notification, td.NotificationType = notificationFromPageData(pageData)

	templatesFS, err := fs.Sub(content, "templates")
	if err != nil {
		log.Printf("renderTemplate %q: opening embedded templates: %v", tmpl, err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	// base.html holds the layout with {{ block "content" . }}; the page
	// template's {{ define "content" }} overrides that block.
	t, err := template.ParseFS(templatesFS, "base.html", "notifications.html", tmpl)
	if err != nil {
		log.Printf("renderTemplate %q: parsing templates: %v", tmpl, err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	var page strings.Builder
	if err := t.ExecuteTemplate(&page, "base.html", td); err != nil {
		log.Printf("renderTemplate %q: executing base.html: %v", tmpl, err)
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}
	// A write error here means the client went away; there is nothing useful
	// left to send.
	_, _ = io.WriteString(w, page.String())
}

// notificationFromPageData returns the Notification and NotificationType
// string fields of pageData when it is a struct, or a pointer to a struct,
// that has them. Missing or non-string fields yield "". A nil pageData
// (or nil pointer) yields two empty strings.
func notificationFromPageData(pageData interface{}) (notification, notificationType string) {
	v := reflect.Indirect(reflect.ValueOf(pageData))
	if v.Kind() != reflect.Struct {
		return "", ""
	}
	if f := v.FieldByName("Notification"); f.IsValid() && f.Kind() == reflect.String {
		notification = f.String()
	}
	if f := v.FieldByName("NotificationType"); f.IsValid() && f.Kind() == reflect.String {
		notificationType = f.String()
	}
	return notification, notificationType
}
// copyFile copies the contents of the file at src to dst. It creates dst via
// os.Create, which truncates an existing file.
//
// The error from closing dst is reported when the copy itself succeeded.
// For a written file, Close can be the only place a failed flush surfaces
// (e.g. a full disk), so dropping it could hide a truncated copy.
func copyFile(src, dst string) (err error) {
	source, err := os.Open(src)
	if err != nil {
		return err
	}
	defer source.Close() // read-only handle; its Close error is not actionable

	destination, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := destination.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = io.Copy(destination, source)
	return err
}
// readLines loads file and returns its non-blank lines, each stripped of
// leading and trailing whitespace (which also removes a CR from CRLF
// endings). It returns a nil slice when the file has no non-blank lines,
// and the os.ReadFile error if the file cannot be read.
func readLines(file string) ([]string, error) {
	raw, err := os.ReadFile(file)
	if err != nil {
		return nil, err
	}
	isNewline := func(r rune) bool { return r == '\n' }

	var kept []string
	for _, segment := range strings.FieldsFunc(string(raw), isNewline) {
		if s := strings.TrimSpace(segment); len(s) > 0 {
			kept = append(kept, s)
		}
	}
	return kept, nil
}
// writeLines writes lines to file joined by "\n", with no trailing newline.
// The file is created with permission 0644 (before umask) or truncated if
// it exists.
//
// It returns the error from os.WriteFile, matching writeURLListCSV, so
// callers can detect a failed save. Existing calls that use writeLines as a
// plain statement still compile.
func writeLines(file string, lines []string) error {
	data := strings.Join(lines, "\n")
	return os.WriteFile(file, []byte(data), 0644)
}
// mergeUnique combines two newline-separated lists into one. Every line is
// trimmed of surrounding whitespace, and blank lines and duplicates (after
// trimming) are dropped. The result is sorted in ascending byte-wise order
// and joined with "\n", with no trailing newline. Two empty inputs yield "".
func mergeUnique(data1, data2 string) string {
	seen := make(map[string]struct{})
	var merged []string
	for _, data := range [...]string{data1, data2} {
		for _, line := range strings.Split(data, "\n") {
			line = strings.TrimSpace(line)
			if line == "" {
				continue
			}
			if _, dup := seen[line]; dup {
				continue
			}
			seen[line] = struct{}{}
			merged = append(merged, line)
		}
	}
	// Sorting gives a deterministic output regardless of input order.
	sort.Strings(merged)
	return strings.Join(merged, "\n")
}
// unsafeInputChars matches the characters removed by sanitizeInput. It is
// compiled once at package initialisation rather than on every call.
var unsafeInputChars = regexp.MustCompile(`[;<>&|]`)

// sanitizeInput returns input with every ';', '<', '>', '&' and '|'
// removed; all other characters are left as they are.
//
// NOTE(review): this is a small blacklist and does not by itself make input
// safe for a shell, HTML, or another interpreter. Quotes, '$', backticks and
// newlines pass through unchanged. Prefer context-specific escaping or
// validation (e.g. isValidDomain / isValidURL) at the point of use.
func sanitizeInput(input string) string {
	return unsafeInputChars.ReplaceAllString(input, "")
}
// isValidURL reports whether u parses as an absolute http or https URL with
// a non-empty host name.
//
// url.Parse lower-cases the scheme, so "HTTPS://example.com" is accepted.
// These are rejected:
//   - a bare "http://" with no host;
//   - strings url.Parse refuses (e.g. a space in the host);
//   - non-HTTP schemes.
func isValidURL(u string) bool {
	parsed, err := neturl.Parse(u)
	if err != nil {
		return false
	}
	if parsed.Scheme != "http" && parsed.Scheme != "https" {
		return false
	}
	// Hostname() strips any port, so "http://:80" has no host.
	return parsed.Hostname() != ""
}
// domainPattern matches one or more labels, each followed by a dot, then an
// alphabetic top-level label of 2-63 letters. Each label is 1-63 ASCII
// letters, digits or hyphens and neither starts nor ends with a hyphen.
// It is compiled once at package initialisation.
var domainPattern = regexp.MustCompile(
	`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,63}$`)

// isValidDomain reports whether d is a syntactically valid multi-label
// domain name of at most 253 characters, e.g. "example.com" or
// "ads.example.co.uk".
//
// It rejects:
//   - empty labels ("a..com");
//   - labels that begin or end with '-';
//   - labels longer than 63 characters;
//   - single-label names;
//   - a trailing dot;
//   - underscores and non-ASCII characters.
func isValidDomain(d string) bool {
	return len(d) <= 253 && domainPattern.MatchString(d)
}
// URLListItem is one row of the URL list CSV file, whose columns are
// Name, Enabled, Group, URL and Note in that order. Field order matches the
// column order and is kept stable.
type URLListItem struct {
	// Name is the value of the Name column.
	Name string
	// Enabled is true when the Enabled column holds the text "true".
	Enabled bool
	// Group, URL and Note hold the remaining columns, in order.
	Group, URL, Note string
}
// readURLListCSV reads the URL list CSV file written by writeURLListCSV.
//
// Each record supplies, in order, Name, Enabled, Group, URL and Note:
//   - every value is trimmed of surrounding whitespace;
//   - Enabled is true only for the exact text "true";
//   - missing trailing columns are left empty.
//
// Blank lines and records with fewer than two columns are skipped. A first
// record whose first column is "Name" is treated as the header and skipped.
//
// Quoted fields (as produced by escapeCSVField) are decoded, so names and
// notes containing commas, quotes or newlines round-trip correctly.
// Hand-edited files stay compatible in two ways:
//   - any columns beyond the fifth are re-joined with commas into Note;
//   - stray quotes inside unquoted fields are kept literally (LazyQuotes).
//
// It returns the error from opening the file or from the CSV reader.
func readURLListCSV(file string) ([]URLListItem, error) {
	f, err := os.Open(file)
	if err != nil {
		return nil, err
	}
	defer f.Close() // read-only; a Close error cannot lose data

	r := csv.NewReader(f)
	r.FieldsPerRecord = -1 // rows may legitimately have 2-5 (or more) columns
	r.LazyQuotes = true

	var items []URLListItem
	for first := true; ; first = false {
		rec, readErr := r.Read()
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return nil, readErr
		}
		if first && strings.TrimSpace(rec[0]) == "Name" {
			continue // header row
		}
		if len(rec) < 2 {
			continue
		}

		item := URLListItem{
			Name:    strings.TrimSpace(rec[0]),
			Enabled: strings.TrimSpace(rec[1]) == "true",
		}
		if len(rec) > 2 {
			item.Group = strings.TrimSpace(rec[2])
		}
		if len(rec) > 3 {
			item.URL = strings.TrimSpace(rec[3])
		}
		if len(rec) > 4 {
			// Unquoted commas after the URL column belong to Note.
			item.Note = strings.TrimSpace(strings.Join(rec[4:], ","))
		}
		items = append(items, item)
	}
	return items, nil
}
// writeURLListCSV overwrites file (permission 0644 before umask) with the
// header row "Name,Enabled,Group,URL,Note" followed by one row per item.
// Every row, including the last, ends in "\n". Name, Group, URL and Note
// are passed through escapeCSVField; Enabled is written as "true" or
// "false". It returns the error from os.WriteFile.
func writeURLListCSV(file string, items []URLListItem) error {
	rows := make([]string, 0, len(items)+1)
	rows = append(rows, "Name,Enabled,Group,URL,Note")
	for _, it := range items {
		flag := "false"
		if it.Enabled {
			flag = "true"
		}
		cells := []string{
			escapeCSVField(it.Name),
			flag,
			escapeCSVField(it.Group),
			escapeCSVField(it.URL),
			escapeCSVField(it.Note),
		}
		rows = append(rows, strings.Join(cells, ","))
	}
	return os.WriteFile(file, []byte(strings.Join(rows, "\n")+"\n"), 0644)
}
// escapeCSVField formats field as a single CSV cell. A field containing a
// comma, double quote, carriage return or newline is wrapped in double
// quotes, with embedded quotes doubled (RFC 4180). Any other field,
// including "", is returned unchanged.
func escapeCSVField(field string) string {
	if !strings.ContainsAny(field, ",\"\r\n") {
		return field
	}
	return `"` + strings.ReplaceAll(field, `"`, `""`) + `"`
}
// generateSecret always returns the empty string. It is a leftover stub kept
// so that any remaining references still compile; TOTP secrets are created
// by generateMFASecret.
//
// Deprecated: use generateMFASecret instead.
func generateSecret() string {
	return ""
}
// validateTOTP reports whether code is accepted by totp.Validate for secret
// at the current time.
//
// All whitespace in code is removed first, so a code typed as "123 456" or
// pasted with a trailing newline is still checked.
//
// An empty secret or an empty code is rejected without calling
// totp.Validate. A code derived from an empty key is computable by anyone,
// so an unset secret must never validate.
func validateTOTP(secret, code string) bool {
	code = strings.Join(strings.Fields(code), "")
	if secret == "" || code == "" {
		return false
	}
	return totp.Validate(code, secret)
}
// generateMFASecret creates a new TOTP key with appname as the issuer and
// username as the account name. It returns the key's secret and the URL
// from key.URL() (presumably the otpauth:// URI shown as a QR code during
// MFA setup).
//
// If totp.Generate fails, the error is logged and two empty strings are
// returned instead of dereferencing a nil key. (pquerna/otp rejects an
// empty Issuer or AccountName — verify for the version in go.mod.) Callers
// should treat an empty secret as "MFA setup unavailable" rather than
// saving it.
func generateMFASecret(username, appname string) (string, string) {
	key, err := totp.Generate(totp.GenerateOpts{
		Issuer:      appname,
		AccountName: username,
	})
	if err != nil {
		log.Printf("generating MFA secret for %q (issuer %q): %v", username, appname, err)
		return "", ""
	}
	return key.Secret(), key.URL()
}