ultisuite-backend/internal/mail/credentials/manager.go
2026-05-24 00:03:36 +02:00

185 lines
4.3 KiB
Go

package credentials
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
)
// prefix is the version tag at the start of every serialized credential blob.
// encryptRaw writes it, decryptRaw requires it, and IsEncrypted checks for it.
const (
	prefix = "UMC1|"
)

// Manager seals and opens mail credentials with AES-GCM using a set of keys
// addressed by id. New blobs are sealed with the active key; a blob is opened
// with whichever key id it names. No method in this file modifies a Manager
// after NewManager returns it.
type Manager struct {
	// activeKeyID names the entry of keys used by encryptRaw.
	activeKeyID string
	// keys maps a key id to raw AES key bytes (16, 24 or 32 bytes when built
	// through parseKeys).
	keys map[string][]byte
}
// NewManager builds a Manager from a key specification.
//
// keysSpec is parsed by parseKeys as a comma-separated list of
// "<keyID>:<base64 AES key>" entries; every parsed key is available for
// decryption. activeKeyID selects the key used for new encryptions.
// Surrounding whitespace in activeKeyID is ignored, matching how parseKeys
// trims key ids. When activeKeyID is empty, the lexicographically smallest
// key id is used.
//
// It returns an error if keysSpec is invalid or empty, or if activeKeyID does
// not name one of the parsed keys.
func NewManager(keysSpec, activeKeyID string) (*Manager, error) {
	keys, err := parseKeys(keysSpec)
	if err != nil {
		return nil, err
	}
	if len(keys) == 0 {
		return nil, errors.New("mail credential keys are required")
	}
	activeKeyID = strings.TrimSpace(activeKeyID)
	if activeKeyID == "" {
		// Map iteration order is randomized, so taking "the first" key would
		// pick a different active key on every start. Choose the smallest id
		// so the default is stable across restarts and instances.
		picked := false
		for keyID := range keys {
			if !picked || keyID < activeKeyID {
				activeKeyID = keyID
				picked = true
			}
		}
	}
	if _, ok := keys[activeKeyID]; !ok {
		return nil, fmt.Errorf("active credential key id %q not found", activeKeyID)
	}
	return &Manager{
		activeKeyID: activeKeyID,
		keys:        keys,
	}, nil
}
// Encrypt seals a username/password pair as an AuthPassword credential. It is
// shorthand for EncryptCredential with AuthType set to AuthPassword.
func (m *Manager) Encrypt(username, password string) ([]byte, error) {
	cred := Credential{
		Username: username,
		Password: password,
		AuthType: AuthPassword,
	}
	return m.EncryptCredential(cred)
}
// EncryptCredential converts c to its stored form, encodes it as JSON and
// seals the result with the active key. An empty AuthType is treated as
// AuthPassword. c is received by value, so the caller's copy is untouched.
func (m *Manager) EncryptCredential(c Credential) ([]byte, error) {
	if c.AuthType == "" {
		c.AuthType = AuthPassword
	}
	stored := c.toStored()
	payload, err := json.Marshal(stored)
	if err != nil {
		return nil, fmt.Errorf("marshal payload: %w", err)
	}
	return m.encryptRaw(payload)
}
// Decrypt opens blob and returns the username and a single secret. The
// secret is the access token when the credential reports IsOAuth, and the
// password otherwise. On error both strings are empty.
func (m *Manager) Decrypt(blob []byte) (string, string, error) {
	cred, err := m.DecryptCredential(blob)
	if err != nil {
		return "", "", err
	}
	switch {
	case cred.IsOAuth():
		return cred.Username, cred.AccessToken, nil
	default:
		return cred.Username, cred.Password, nil
	}
}
// DecryptCredential opens blob with the key it names, decodes the JSON
// payload into a storedPayload and converts it via storedToCredential.
// The first failing step's error is returned with a zero Credential.
func (m *Manager) DecryptCredential(blob []byte) (Credential, error) {
	raw, err := m.decryptRaw(blob)
	if err != nil {
		return Credential{}, err
	}
	var stored storedPayload
	if err = json.Unmarshal(raw, &stored); err != nil {
		return Credential{}, fmt.Errorf("unmarshal payload: %w", err)
	}
	return storedToCredential(stored)
}
// encryptRaw seals rawPayload with AES-GCM under the active key and a fresh
// random nonce. It returns "UMC1|<keyID>|<nonce>|<ciphertext>", with nonce and
// ciphertext in standard base64. No additional authenticated data is used.
func (m *Manager) encryptRaw(rawPayload []byte) ([]byte, error) {
	keyID := m.activeKeyID
	blockCipher, err := aes.NewCipher(m.keys[keyID])
	if err != nil {
		return nil, fmt.Errorf("new cipher: %w", err)
	}
	aead, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return nil, fmt.Errorf("new gcm: %w", err)
	}
	nonce := make([]byte, aead.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, fmt.Errorf("nonce: %w", err)
	}
	sealed := aead.Seal(nil, nonce, rawPayload, nil)

	// prefix already ends with the field separator.
	enc := base64.StdEncoding
	out := prefix + keyID + "|" + enc.EncodeToString(nonce) + "|" + enc.EncodeToString(sealed)
	return []byte(out), nil
}
// decryptRaw parses a blob of the form "UMC1|<keyID>|<nonce>|<ciphertext>"
// (nonce and ciphertext in standard base64), looks up the named key and
// returns the AES-GCM authenticated plaintext.
//
// Every malformed input (wrong framing, unknown key id, bad base64, wrong
// nonce length, failed authentication) is reported as an error; it must never
// panic, since blobs come from storage and may be truncated or tampered with.
func (m *Manager) decryptRaw(blob []byte) ([]byte, error) {
	parts := strings.Split(string(blob), "|")
	if len(parts) != 4 || parts[0] != strings.TrimSuffix(prefix, "|") {
		return nil, errors.New("credentials payload is not encrypted with supported format")
	}
	keyID, nonceB64, ciphertextB64 := parts[1], parts[2], parts[3]
	key, ok := m.keys[keyID]
	if !ok {
		return nil, fmt.Errorf("unknown credential key id %q", keyID)
	}
	nonce, err := base64.StdEncoding.DecodeString(nonceB64)
	if err != nil {
		return nil, fmt.Errorf("decode nonce: %w", err)
	}
	ciphertext, err := base64.StdEncoding.DecodeString(ciphertextB64)
	if err != nil {
		return nil, fmt.Errorf("decode ciphertext: %w", err)
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("new cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("new gcm: %w", err)
	}
	// AEAD.Open panics when given a nonce of the wrong length, so reject it
	// here instead of letting a corrupt blob crash the caller.
	if len(nonce) != gcm.NonceSize() {
		return nil, fmt.Errorf("invalid nonce length %d, want %d", len(nonce), gcm.NonceSize())
	}
	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return nil, fmt.Errorf("decrypt: %w", err)
	}
	return plaintext, nil
}
// IsEncrypted reports whether blob starts with the "UMC1|" version tag.
// It inspects only the prefix; it does not validate the rest of the format
// or whether the blob can actually be decrypted.
func IsEncrypted(blob []byte) bool {
	n := len(prefix)
	if len(blob) < n {
		return false
	}
	return strings.HasPrefix(string(blob[:n]), prefix)
}
// parseKeys parses a comma-separated key specification of the form
// "id1:base64key1,id2:base64key2". Whitespace around entries, ids and keys is
// ignored, as are empty entries, so "" yields an empty map. Each key must be
// standard base64 decoding to 16, 24 or 32 bytes (AES-128/192/256).
//
// Key ids must be non-empty and unique, and must not contain "|". "|" separates
// the fields of an encrypted blob, so an id containing it would produce blobs
// that can never be parsed back.
func parseKeys(spec string) (map[string][]byte, error) {
	entries := strings.Split(spec, ",")
	keys := make(map[string][]byte, len(entries))
	for i, entry := range entries {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		pair := strings.SplitN(entry, ":", 2)
		if len(pair) != 2 {
			// Do not echo the entry: without an id separator it is most likely
			// a bare base64 key, and this error may end up in logs.
			return nil, fmt.Errorf("invalid key entry %d: expected <id>:<base64 key>", i+1)
		}
		keyID := strings.TrimSpace(pair[0])
		if keyID == "" {
			return nil, errors.New("key id cannot be empty")
		}
		if strings.Contains(keyID, "|") {
			return nil, fmt.Errorf("key id %q must not contain %q", keyID, "|")
		}
		if _, dup := keys[keyID]; dup {
			return nil, fmt.Errorf("duplicate key id %q", keyID)
		}
		rawKey, err := base64.StdEncoding.DecodeString(strings.TrimSpace(pair[1]))
		if err != nil {
			return nil, fmt.Errorf("decode key %s: %w", keyID, err)
		}
		if l := len(rawKey); l != 16 && l != 24 && l != 32 {
			return nil, fmt.Errorf("key %s must be AES-128/192/256", keyID)
		}
		keys[keyID] = rawKey
	}
	return keys, nil
}