Files
unifi-adblocker/handlers.go
T

525 lines
14 KiB
Go

package main
import (
"encoding/json"
"net/http"
"regexp"
"strconv"
"strings"
"golang.org/x/crypto/bcrypt"
)
// sessionName is the session name passed to store.Get by the handlers in this
// file; the MFA enrolment state ("mfa_setup_secret"/"mfa_setup_url") is kept
// in this session.
const sessionName = "session"
// profileHandler renders profile.html and handles the actions posted in the
// "action" form field:
//
//   - "change_password": verifies the current password (and the TOTP code when
//     MFA is enabled), stores a bcrypt hash of the new password, expires the
//     session and redirects to /login.
//   - "toggle_mfa": with mfa=on, starts enrolment by keeping a freshly generated
//     secret and provisioning URL in the session; with mfa=off, clears the
//     stored MFA secret.
//   - "confirm_mfa": persists the pending secret once a valid code is supplied.
//   - "cancel_mfa": discards the pending secret.
//
// Shared config fields are snapshotted under configMu.RLock and written under
// configMu.Lock; the lock is released before saveConfig is called.
func profileHandler(w http.ResponseWriter, r *http.Request) {
	configMu.RLock()
	hashedPass := config.HashedPass
	mfaSecret := config.MFASecret
	username := config.Username
	appName := config.AppName
	configMu.RUnlock()

	// NOTE(review): the error is ignored as before; presumably store.Get still
	// returns a usable (new) session when the cookie cannot be decoded.
	session, _ := store.Get(r, sessionName)
	data := TemplateData{
		MFAEnabled: mfaSecret != "",
	}

	// Restore an MFA enrolment that was started on an earlier request.
	if setupSecret, ok := session.Values["mfa_setup_secret"].(string); ok && setupSecret != "" {
		if setupURL, ok := session.Values["mfa_setup_url"].(string); ok {
			data.MFASetupInProgress = true
			data.MFASecret = setupSecret
			data.MFAURL = setupURL
		}
	}

	// fail renders the page with an error notification; callers return after it.
	fail := func(msg string) {
		data.Notification = msg
		data.NotificationType = "error"
		renderTemplate(w, r, "profile.html", data)
	}

	if r.Method == http.MethodPost {
		switch r.FormValue("action") {
		case "change_password":
			current := r.FormValue("current_password")
			newpw := r.FormValue("new_password")
			confirmpw := r.FormValue("confirm_new_password")
			mfaCode := r.FormValue("mfa_code")

			if bcrypt.CompareHashAndPassword([]byte(hashedPass), []byte(current)) != nil {
				fail("Current password is incorrect.")
				return
			}
			if newpw != confirmpw {
				fail("New passwords do not match.")
				return
			}
			if mfaSecret != "" && !validateTOTP(mfaSecret, mfaCode) {
				fail("Invalid MFA code.")
				return
			}
			if len(newpw) < 8 {
				fail("Password must be at least 8 characters.")
				return
			}
			// bcrypt only considers the first 72 bytes of its input, and recent
			// x/crypto versions reject longer passwords with an error.
			if len(newpw) > 72 {
				fail("Password must be at most 72 bytes.")
				return
			}
			hash, err := bcrypt.GenerateFromPassword([]byte(newpw), bcrypt.DefaultCost)
			if err != nil {
				// Never store the empty hash that accompanies an error: it would
				// make every subsequent login fail.
				fail("Could not set new password.")
				return
			}
			configMu.Lock()
			config.HashedPass = string(hash)
			configMu.Unlock()
			saveConfig()

			// Expire the session so the user must log in with the new password.
			session.Options.MaxAge = -1
			session.Save(r, w)
			http.Redirect(w, r, "/login", http.StatusSeeOther)
			return

		case "toggle_mfa":
			mfaAction := r.FormValue("mfa")
			if mfaAction == "on" && mfaSecret == "" {
				// The secret stays in the session until confirm_mfa proves the
				// authenticator app was configured correctly.
				secret, qrURL := generateMFASecret(username, appName)
				session.Values["mfa_setup_secret"] = secret
				session.Values["mfa_setup_url"] = qrURL
				session.Save(r, w)
				data.MFASetupInProgress = true
				data.MFASecret = secret
				data.MFAURL = qrURL
				data.Notification = "MFA setup initiated. Please configure your authenticator app."
				data.NotificationType = "info"
			} else if mfaAction == "off" && mfaSecret != "" {
				// NOTE(review): unlike change_password, disabling MFA requires no
				// password or TOTP re-check; consider requiring one.
				configMu.Lock()
				config.MFASecret = ""
				configMu.Unlock()
				saveConfig()
				data.Notification = "MFA disabled."
				data.NotificationType = "success"
				data.MFAEnabled = false
			}

		case "confirm_mfa":
			mfaCode := r.FormValue("mfa_code")
			if setupSecret, ok := session.Values["mfa_setup_secret"].(string); ok && setupSecret != "" {
				if validateTOTP(setupSecret, mfaCode) {
					configMu.Lock()
					config.MFASecret = setupSecret
					configMu.Unlock()
					saveConfig()
					// Enrolment finished: drop the pending state from the session.
					delete(session.Values, "mfa_setup_secret")
					delete(session.Values, "mfa_setup_url")
					session.Save(r, w)
					data.Notification = "MFA enabled successfully."
					data.NotificationType = "success"
					data.MFAEnabled = true
					data.MFASetupInProgress = false
					data.MFASecret = ""
					data.MFAURL = ""
				} else {
					data.Notification = "Invalid MFA code. Please try again."
					data.NotificationType = "error"
				}
			}

		case "cancel_mfa":
			delete(session.Values, "mfa_setup_secret")
			delete(session.Values, "mfa_setup_url")
			session.Save(r, w)
			data.MFASetupInProgress = false
			data.MFASecret = ""
			data.MFAURL = ""
			data.Notification = "MFA setup cancelled."
			data.NotificationType = "info"
		}
	}
	renderTemplate(w, r, "profile.html", data)
}
func mfaSetupHandler(w http.ResponseWriter, r *http.Request) {
if config.MFASecret != "" {
http.Error(w, "MFA already enabled", 400)
return
}
secret, qrURL := generateMFASecret(config.Username, config.AppName)
// Store in session
session, _ := store.Get(r, sessionName)
session.Values["mfa_setup_secret"] = secret
session.Values["mfa_setup_url"] = qrURL
session.Save(r, w)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"secret": secret, "url": qrURL})
}
// dashboardHandler renders the landing page with the domains.html template,
// passing an empty domain list, an empty query and no notification.
func dashboardHandler(w http.ResponseWriter, r *http.Request) {
	emptyDomains := []string{}
	renderTemplate(w, r, "domains.html", map[string]interface{}{
		"Query":            "",
		"Domains":          emptyDomains,
		"NotificationType": "",
		"Notification":     "",
	})
}
func urlListsHandler(w http.ResponseWriter, r *http.Request) {
configMu.RLock()
urlFileList := config.URLFileList
defaultURLs := config.DefaultURLs
configMu.RUnlock()
if r.Method == "POST" {
action := r.FormValue("action")
items, err := readURLListCSV(urlFileList)
if err != nil {
renderTemplate(w, r, "urllists.html", struct {
Items []URLListItem
Notification string
NotificationType string
}{
Items: items,
Notification: "Error reading URL list",
NotificationType: "error",
})
return
}
var errorMsg string
switch action {
case "add":
name := sanitizeInput(r.FormValue("name"))
urlStr := sanitizeInput(r.FormValue("url"))
group := sanitizeInput(r.FormValue("group"))
note := sanitizeInput(r.FormValue("note"))
if name == "" {
errorMsg = "Name is required"
} else if !isValidURL(urlStr) {
errorMsg = "Invalid URL format. Must start with http:// or https://"
} else {
items = append(items, URLListItem{
Name: name,
Enabled: true,
Group: group,
URL: urlStr,
Note: note,
})
}
case "remove":
indexStr := r.FormValue("index")
index, _ := strconv.Atoi(indexStr)
if index >= 0 && index < len(items) {
items = append(items[:index], items[index+1:]...)
} else {
errorMsg = "Invalid index"
}
case "toggle":
indexStr := r.FormValue("index")
index, _ := strconv.Atoi(indexStr)
if index >= 0 && index < len(items) {
items[index].Enabled = !items[index].Enabled
} else {
errorMsg = "Invalid index"
}
case "edit":
indexStr := r.FormValue("index")
name := sanitizeInput(r.FormValue("name"))
urlStr := sanitizeInput(r.FormValue("url"))
group := sanitizeInput(r.FormValue("group"))
note := sanitizeInput(r.FormValue("note"))
index, _ := strconv.Atoi(indexStr)
if index >= 0 && index < len(items) {
if name == "" {
errorMsg = "Name is required"
} else if !isValidURL(urlStr) {
errorMsg = "Invalid URL format. Must start with http:// or https://"
} else {
items[index].Name = name
items[index].URL = urlStr
items[index].Group = group
items[index].Note = note
}
} else {
errorMsg = "Invalid index"
}
}
if errorMsg != "" {
renderTemplate(w, r, "urllists.html", struct {
Items []URLListItem
Notification string
NotificationType string
}{
Items: items,
Notification: errorMsg,
NotificationType: "error",
})
return
}
err = writeURLListCSV(urlFileList, items)
if err != nil {
renderTemplate(w, r, "urllists.html", struct {
Items []URLListItem
Notification string
NotificationType string
}{
Items: items,
Notification: "Error saving URL list",
NotificationType: "error",
})
return
}
renderTemplate(w, r, "urllists.html", struct {
Items []URLListItem
Notification string
NotificationType string
}{
Items: items,
Notification: "URL list updated successfully",
NotificationType: "success",
})
return
}
items, err := readURLListCSV(urlFileList)
if err != nil || len(items) == 0 {
// Use default blocklists
defaultItems := defaultURLs
writeURLListCSV(urlFileList, defaultItems)
items = defaultItems
}
renderTemplate(w, r, "urllists.html", struct {
Items []URLListItem
}{
Items: items,
})
}
// domainsHandler renders domains.html with no search results: an empty domain
// list, an empty query and blank notification fields.
func domainsHandler(w http.ResponseWriter, r *http.Request) {
	pageData := make(map[string]interface{}, 4)
	pageData["Domains"] = make([]string, 0)
	for _, key := range []string{"Query", "Notification", "NotificationType"} {
		pageData[key] = ""
	}
	renderTemplate(w, r, "domains.html", pageData)
}
// whitelistHandler renders whitelist.html.
//
// On a POST with action=add_domain, it appends the trimmed "domain" form value
// to config.WhitelistFile, one domain per line. An empty value, a value with
// embedded whitespace, or an existing entry (compared case-insensitively) is
// rejected. A failure to save is reported instead of claiming success.
func whitelistHandler(w http.ResponseWriter, r *http.Request) {
	pageData := map[string]interface{}{
		"Domains":          []string{},
		"Query":            "",
		"Notification":     "",
		"NotificationType": "",
	}
	notify := func(msg, msgType string) {
		pageData["Notification"] = msg
		pageData["NotificationType"] = msgType
	}

	if r.Method == http.MethodPost && r.FormValue("action") == "add_domain" {
		domain := strings.TrimSpace(r.FormValue("domain"))
		switch {
		case domain == "":
			notify("Domain cannot be empty.", "error")
		case strings.ContainsAny(domain, " \t\r\n"):
			// The whitelist is stored one entry per line; an embedded newline
			// would smuggle additional entries into the file.
			notify("Domain must not contain whitespace.", "error")
		default:
			configMu.RLock()
			whitelistFile := config.WhitelistFile
			configMu.RUnlock()

			lines, err := readLines(whitelistFile)
			if err != nil {
				// Treated as an empty whitelist (presumably the file does not
				// exist yet).
				lines = []string{}
			}

			exists := false
			for _, d := range lines {
				if strings.EqualFold(strings.TrimSpace(d), domain) {
					exists = true
					break
				}
			}

			if exists {
				notify("Domain already in whitelist.", "error")
			} else if err := writeLines(whitelistFile, append(lines, domain)); err != nil {
				notify("Error saving whitelist.", "error")
			} else {
				notify("Domain added to whitelist.", "success")
			}
		}
	}
	renderTemplate(w, r, "whitelist.html", pageData)
}
// searchDomainsHandler answers a JSON search over config.BlocklistFile.
//
// The trimmed "query" form value is matched case-insensitively against whole
// lines. Without a '*' it must match a line exactly. With '*', it is treated as
// a glob in which each '*' matches any run of characters. At most 100 matching
// lines are returned as {"domains": [...]}. An empty query, an unreadable file
// or no matches all yield an empty JSON array, never null.
func searchDomainsHandler(w http.ResponseWriter, r *http.Request) {
	const maxResults = 100

	configMu.RLock()
	blocklistFile := config.BlocklistFile
	configMu.RUnlock()

	// Start from a non-nil slice: encoding/json writes a nil slice as null,
	// which differs from the empty-query response.
	filtered := []string{}
	respond := func() {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{"domains": filtered})
	}

	query := strings.TrimSpace(r.FormValue("query"))
	if query == "" {
		respond()
		return
	}
	domains, err := readLines(blocklistFile)
	if err != nil {
		respond()
		return
	}

	matches := func(d string) bool { return strings.EqualFold(d, query) }
	if strings.Contains(query, "*") {
		// Escape every metacharacter, then turn each escaped '*' into ".*".
		pattern := strings.ReplaceAll(regexp.QuoteMeta(query), `\*`, ".*")
		re, err := regexp.Compile("(?i)^" + pattern + "$")
		if err != nil {
			respond()
			return
		}
		matches = re.MatchString
	}

	for _, d := range domains {
		if !matches(d) {
			continue
		}
		filtered = append(filtered, d)
		if len(filtered) >= maxResults {
			break
		}
	}
	respond()
}
// searchWhitelistHandler searches the union of config.RemoveFile and
// config.WhitelistFile and returns {"results": [{"domain", "source"}, ...]}.
//
// Domains are lowercased and deduplicated. "source" is "Unifi" for domains
// only in RemoveFile, "Custom" for domains only in WhitelistFile, and "Both"
// otherwise. Matching follows searchDomainsHandler: an exact case-insensitive
// match, or a glob when the query contains '*'.
//
// Results follow file order (RemoveFile first, then WhitelistFile) so that the
// 100-result cap is deterministic. An empty query, both files unreadable, or no
// matches yield an empty JSON array, never null.
func searchWhitelistHandler(w http.ResponseWriter, r *http.Request) {
	const maxResults = 100

	configMu.RLock()
	removeFile := config.RemoveFile
	whitelistFile := config.WhitelistFile
	configMu.RUnlock()

	results := []map[string]string{}
	respond := func() {
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(map[string]interface{}{"results": results})
	}

	query := strings.TrimSpace(r.FormValue("query"))
	if query == "" {
		respond()
		return
	}

	removeDomains, errRemove := readLines(removeFile)
	customDomains, errCustom := readLines(whitelistFile)
	if errRemove != nil && errCustom != nil {
		respond()
		return
	}

	// Build the matcher once rather than recompiling the regexp per domain.
	matches := func(d string) bool { return strings.EqualFold(d, query) }
	if strings.Contains(query, "*") {
		pattern := strings.ReplaceAll(regexp.QuoteMeta(query), `\*`, ".*")
		re, err := regexp.Compile("(?i)^" + pattern + "$")
		if err != nil {
			respond()
			return
		}
		matches = re.MatchString
	}

	inRemove := make(map[string]bool, len(removeDomains))
	for _, d := range removeDomains {
		inRemove[strings.ToLower(d)] = true
	}
	inCustom := make(map[string]bool, len(customDomains))
	for _, d := range customDomains {
		inCustom[strings.ToLower(d)] = true
	}

	seen := make(map[string]bool, len(inRemove)+len(inCustom))
collect:
	for _, list := range [][]string{removeDomains, customDomains} {
		for _, raw := range list {
			d := strings.ToLower(raw)
			if seen[d] {
				continue
			}
			seen[d] = true
			if !matches(d) {
				continue
			}

			source := "Custom"
			switch {
			case inRemove[d] && inCustom[d]:
				source = "Both"
			case inRemove[d]:
				source = "Unifi"
			}
			results = append(results, map[string]string{"domain": d, "source": source})
			if len(results) >= maxResults {
				break collect
			}
		}
	}
	respond()
}
// applyHandler runs updateBlocklist synchronously with userInitiated set to
// true for the duration of the call, then redirects to "/".
//
// The flag is reset in a defer so that a panic inside updateBlocklist cannot
// leave it stuck at true. net/http recovers handler panics, so the process
// would otherwise keep running with the flag set.
//
// NOTE(review): concurrent applies share the single flag; the first one to
// finish clears it while another may still be running.
func applyHandler(w http.ResponseWriter, r *http.Request) {
	setUserInitiated := func(v bool) {
		userInitiatedMu.Lock()
		userInitiated = v
		userInitiatedMu.Unlock()
	}

	func() {
		setUserInitiated(true)
		defer setUserInitiated(false)
		updateBlocklist(1) // Force
	}()

	http.Redirect(w, r, "/", http.StatusSeeOther)
}
// logsHandler is a stub for the logs page: it answers 200 with a fixed
// plain-text body until real log retrieval is implemented.
func logsHandler(w http.ResponseWriter, r *http.Request) {
	const placeholder = "Logs placeholder"
	hdr := w.Header()
	hdr.Set("Content-Type", "text/plain")
	w.WriteHeader(http.StatusOK)
	w.Write([]byte(placeholder))
}