Files
mailgosend/internal/crypto/crypto.go
T
2026-05-21 20:27:58 +00:00

217 lines
6.4 KiB
Go

// Package crypto provides AES-256-GCM encryption, HKDF key derivation,
// bcrypt helpers, and secure random utilities.
// All encryption uses authenticated encryption — tampering is detected.
package crypto
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/hex"
"fmt"
"io"
"golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/hkdf"
)
// Package-wide parameters for the AES-GCM and bcrypt helpers below.
const (
	// keyLen is the size of an AES-256 key in bytes (256 bits).
	keyLen = 32

	// gcmNonceLen is the nonce size, in bytes, prepended to every blob that
	// Encrypt produces (the standard 96-bit GCM nonce).
	gcmNonceLen = 12

	// BcryptCost is the bcrypt work factor HashPassword hashes with.
	BcryptCost = 12
)
// Crypto holds the master encryption key.
// One instance per application, injected everywhere that needs encryption.
// The master key is only ever fed into HKDF (DeriveKey / DeriveKeyGlobal) to
// produce purpose-specific subkeys; it is never used as an AES key itself.
// Construct it with New; the zero value holds an all-zero key and should not
// be used.
type Crypto struct {
	// masterKey is a private copy of the 32-byte root secret passed to New.
	masterKey [keyLen]byte
}
// New builds a Crypto around masterKey, which must be exactly keyLen (32)
// bytes long; any other length is rejected with an error. The bytes are
// copied into the returned value, so the caller may reuse or wipe its slice
// afterwards without affecting the Crypto.
func New(masterKey []byte) (*Crypto, error) {
	if got := len(masterKey); got != keyLen {
		return nil, fmt.Errorf("crypto: master key must be %d bytes, got %d", keyLen, got)
	}
	var inst Crypto
	copy(inst.masterKey[:], masterKey)
	return &inst, nil
}
// DeriveKey derives the 32-byte AES-256 subkey bound to the pair
// (purpose, userID). It runs HKDF-SHA256 with the master key as input keying
// material, no salt, and the info string "<purpose>:user:<userID>". Distinct
// pairs therefore get independent keys, and the master key itself never
// touches a cipher.
//
// NOTE(review): this info format overlaps DeriveKeyGlobal's "global:<purpose>".
// For example, DeriveKey("global:x", 1) and DeriveKeyGlobal("x:user:1") derive
// the same key. Keep purpose strings free of those patterns. Changing the
// format would invalidate data that is already encrypted.
func (c *Crypto) DeriveKey(purpose string, userID int64) ([keyLen]byte, error) {
	var out [keyLen]byte
	info := []byte(fmt.Sprintf("%s:user:%d", purpose, userID))
	kdf := hkdf.New(sha256.New, c.masterKey[:], nil, info)
	_, err := io.ReadFull(kdf, out[:])
	if err != nil {
		return out, fmt.Errorf("hkdf derive: %w", err)
	}
	return out, nil
}
// DeriveKeyGlobal derives a 32-byte subkey that is not tied to any user. It is
// meant for system-owned secrets such as DKIM private keys stored per domain.
// It runs HKDF-SHA256 over the master key with no salt and the info string
// "global:<purpose>".
//
// NOTE(review): "global:<purpose>" can coincide with DeriveKey's
// "<purpose>:user:<id>" info string. Avoid purposes containing ":user:";
// changing the format would orphan existing ciphertexts.
func (c *Crypto) DeriveKeyGlobal(purpose string) ([keyLen]byte, error) {
	kdf := hkdf.New(sha256.New, c.masterKey[:], nil, []byte("global:"+purpose))
	var out [keyLen]byte
	if _, err := io.ReadFull(kdf, out[:]); err != nil {
		return out, fmt.Errorf("hkdf derive global: %w", err)
	}
	return out, nil
}
// Encrypt seals plaintext with AES-256-GCM under key and returns the opaque
// blob nonce || ciphertext || tag. A fresh random gcmNonceLen-byte nonce is
// read from crypto/rand on every call, so encrypting the same plaintext twice
// yields different blobs. No additional authenticated data is used.
//
// A nil plaintext is accepted and treated exactly like an empty one. The
// result is then just a nonce followed by the authentication tag.
//
// Errors are returned only if the cipher cannot be constructed or the system
// random source fails.
func Encrypt(key [keyLen]byte, plaintext []byte) ([]byte, error) {
	block, err := aes.NewCipher(key[:])
	if err != nil {
		return nil, fmt.Errorf("aes new cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("aes gcm: %w", err)
	}
	// Size the output once. The first gcmNonceLen bytes hold the nonce, and
	// Seal appends ciphertext||tag into the remaining capacity without
	// reallocating.
	out := make([]byte, gcmNonceLen, gcmNonceLen+len(plaintext)+gcm.Overhead())
	if _, err := rand.Read(out); err != nil {
		return nil, fmt.Errorf("rand nonce: %w", err)
	}
	return gcm.Seal(out, out, plaintext, nil), nil
}
// Decrypt opens a nonce || ciphertext || tag blob produced by Encrypt with the
// same key and returns the plaintext.
//
// A blob too short to hold a nonce plus a GCM authentication tag is rejected
// up front with "ciphertext too short". Any other authentication failure
// (wrong key, tampered or truncated data) is returned as "decrypt: ..."
// wrapping the cipher's Open error. That error is deliberately generic and
// does not say why authentication failed.
func Decrypt(key [keyLen]byte, ciphertext []byte) ([]byte, error) {
	block, err := aes.NewCipher(key[:])
	if err != nil {
		return nil, fmt.Errorf("aes new cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("aes gcm: %w", err)
	}
	// Every valid blob carries a nonce and a full tag even for empty
	// plaintext. Anything shorter cannot possibly authenticate.
	if len(ciphertext) < gcmNonceLen+gcm.Overhead() {
		return nil, fmt.Errorf("ciphertext too short")
	}
	nonce, sealed := ciphertext[:gcmNonceLen], ciphertext[gcmNonceLen:]
	plaintext, err := gcm.Open(nil, nonce, sealed, nil)
	if err != nil {
		return nil, fmt.Errorf("decrypt: %w", err)
	}
	return plaintext, nil
}
// EncryptForUser encrypts plaintext under the subkey that DeriveKey yields for
// (purpose, userID). The result must be opened with DecryptForUser using the
// same userID and purpose. Note the argument order here is (userID, purpose),
// which is the reverse of DeriveKey's.
func (c *Crypto) EncryptForUser(userID int64, purpose string, plaintext []byte) ([]byte, error) {
	subkey, deriveErr := c.DeriveKey(purpose, userID)
	if deriveErr == nil {
		return Encrypt(subkey, plaintext)
	}
	return nil, deriveErr
}
// DecryptForUser reverses EncryptForUser. It re-derives the (purpose, userID)
// subkey via DeriveKey and passes the blob to Decrypt. Both userID and purpose
// must match the values used at encryption time, otherwise Decrypt reports an
// error.
func (c *Crypto) DecryptForUser(userID int64, purpose string, ciphertext []byte) ([]byte, error) {
	var (
		subkey [keyLen]byte
		err    error
	)
	if subkey, err = c.DeriveKey(purpose, userID); err != nil {
		return nil, err
	}
	return Decrypt(subkey, ciphertext)
}
// EncryptGlobal encrypts plaintext under the user-independent subkey that
// DeriveKeyGlobal yields for purpose. Use DecryptGlobal with the same purpose
// to open the result.
func (c *Crypto) EncryptGlobal(purpose string, plaintext []byte) ([]byte, error) {
	subkey, deriveErr := c.DeriveKeyGlobal(purpose)
	switch {
	case deriveErr != nil:
		return nil, deriveErr
	default:
		return Encrypt(subkey, plaintext)
	}
}
// DecryptGlobal reverses EncryptGlobal. It re-derives the global subkey for
// purpose via DeriveKeyGlobal and hands the blob to Decrypt.
func (c *Crypto) DecryptGlobal(purpose string, ciphertext []byte) ([]byte, error) {
	subkey, deriveErr := c.DeriveKeyGlobal(purpose)
	if deriveErr == nil {
		return Decrypt(subkey, ciphertext)
	}
	return nil, deriveErr
}
// ---- Bcrypt ----
// HashPassword returns the bcrypt hash of password computed at work factor
// BcryptCost, as a string suitable for storage. An empty password is rejected
// before any hashing. Errors reported by bcrypt itself are returned wrapped
// with a "bcrypt:" prefix.
func HashPassword(password string) (string, error) {
	if len(password) == 0 {
		return "", fmt.Errorf("password must not be empty")
	}
	digest, hashErr := bcrypt.GenerateFromPassword([]byte(password), BcryptCost)
	if hashErr == nil {
		return string(digest), nil
	}
	return "", fmt.Errorf("bcrypt: %w", hashErr)
}
// CheckPassword verifies password against a bcrypt hash previously produced by
// HashPassword. It returns nil on a match and bcrypt's error otherwise, either
// a mismatch or an unusable hash. The comparison is delegated entirely to
// bcrypt.CompareHashAndPassword.
func CheckPassword(hash, password string) error {
	stored, candidate := []byte(hash), []byte(password)
	return bcrypt.CompareHashAndPassword(stored, candidate)
}
// ---- Session tokens ----
// NewToken creates a session token from 32 bytes of crypto/rand output.
// raw is the 64-character hex encoding that is sent to the client. hash is
// HashToken(raw), the SHA-256 hex digest to persist server-side; only the
// client keeps the raw value. The only possible error is a failure of the
// system random source.
func NewToken() (raw string, hash string, err error) {
	var secret [32]byte
	if _, readErr := rand.Read(secret[:]); readErr != nil {
		return "", "", fmt.Errorf("rand token: %w", readErr)
	}
	token := hex.EncodeToString(secret[:])
	return token, HashToken(token), nil
}
// HashToken returns the lowercase hex encoding of SHA-256(raw). The result is
// always 64 characters long. It is the value NewToken returns for storage and
// the value to recompute when looking up a token presented by a client.
func HashToken(raw string) string {
	hasher := sha256.New()
	hasher.Write([]byte(raw)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
// SecureCompare reports whether a == b in a way that does not reveal, through
// timing, where the two strings first differ.
//
// subtle.ConstantTimeCompare returns immediately when its inputs have
// different lengths. Comparing the raw strings would therefore let an attacker
// probe the exact length of a secret. Instead, both inputs are first reduced
// to fixed-size SHA-256 digests, and the digests are compared in constant
// time. A false match would require a SHA-256 collision.
//
// Residual leak: hashing time still grows with input length in 64-byte
// blocks, so only coarse length information remains observable.
func SecureCompare(a, b string) bool {
	digestA := sha256.Sum256([]byte(a))
	digestB := sha256.Sum256([]byte(b))
	return subtle.ConstantTimeCompare(digestA[:], digestB[:]) == 1
}
// ---- Random helpers ----
// RandomHex returns n bytes read from crypto/rand, encoded as a lowercase hex
// string of length 2n. n == 0 yields "". A negative n is reported as an error
// instead of panicking inside make. The only other error is a failure of the
// system random source.
func RandomHex(n int) (string, error) {
	if n < 0 {
		return "", fmt.Errorf("rand: negative length %d", n)
	}
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("rand: %w", err)
	}
	return hex.EncodeToString(b), nil
}
// RandomBytes returns a fresh slice of n bytes read from crypto/rand.
// n == 0 yields an empty, non-nil slice. A negative n is reported as an error
// instead of panicking inside make. The only other error is a failure of the
// system random source.
func RandomBytes(n int) ([]byte, error) {
	if n < 0 {
		return nil, fmt.Errorf("rand: negative length %d", n)
	}
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		return nil, fmt.Errorf("rand: %w", err)
	}
	return b, nil
}