ultisuite-backend/internal/mail/credentials/manager.go

161 lines
3.8 KiB
Go

package credentials
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
)
// prefix starts every blob written by Manager.Encrypt: the format tag
// "UMC1" followed by the "|" field separator.
const prefix = "UMC1|"

// Manager seals and opens mail credentials with AES-GCM. It holds several
// keys addressed by ID. Each blob records the ID of the key that sealed it,
// so any configured key can open it, while new blobs use the active key.
type Manager struct {
	// keys maps a key ID to its raw AES key. NewManager only admits
	// 16-, 24- or 32-byte keys.
	keys map[string][]byte
	// activeKeyID names the entry of keys that Encrypt uses.
	activeKeyID string
}

// payload is the JSON document sealed inside each blob. The field order
// determines the JSON encoding, so keep it stable.
type payload struct {
	Username string `json:"username"`
	Password string `json:"password"`
}
// NewManager builds a Manager from keysSpec and selects the key Encrypt uses.
//
// keysSpec is a comma-separated list of "keyID:base64Key" entries (see
// parseKeys); at least one key is required. activeKeyID names the key new
// blobs are sealed with; surrounding whitespace is ignored, as it is for the
// IDs in keysSpec. If activeKeyID is empty, the lexicographically smallest
// key ID is used.
//
// An error is returned if keysSpec is invalid or empty, or if activeKeyID
// does not name one of the parsed keys.
func NewManager(keysSpec, activeKeyID string) (*Manager, error) {
	keys, err := parseKeys(keysSpec)
	if err != nil {
		return nil, err
	}
	if len(keys) == 0 {
		return nil, errors.New("mail credential keys are required")
	}
	activeKeyID = strings.TrimSpace(activeKeyID)
	if activeKeyID == "" {
		// Map iteration order is randomized, so taking "the first" key would
		// change the encryption key from one process start to the next.
		// Pick the smallest ID so the default is stable. parseKeys never
		// yields an empty ID, so "" works as the "nothing chosen yet" marker.
		for keyID := range keys {
			if activeKeyID == "" || keyID < activeKeyID {
				activeKeyID = keyID
			}
		}
	}
	if _, ok := keys[activeKeyID]; !ok {
		return nil, fmt.Errorf("active credential key id %q not found", activeKeyID)
	}
	return &Manager{
		activeKeyID: activeKeyID,
		keys:        keys,
	}, nil
}
// Encrypt seals username and password with AES-GCM under the active key and
// returns a blob of the form
//
//	UMC1|<activeKeyID>|<base64 nonce>|<base64 ciphertext>
//
// Each call draws a fresh random nonce. Both username and password must be
// non-empty, because Decrypt rejects a payload that lacks either one. An error
// is also returned if the active key is not configured, if its ID contains
// the "|" field separator, or if cipher setup, nonce generation, or JSON
// encoding fails.
func (m *Manager) Encrypt(username, password string) ([]byte, error) {
	// Refuse to produce a blob that Decrypt would never open.
	if username == "" || password == "" {
		return nil, errors.New("credentials are incomplete")
	}
	key, ok := m.keys[m.activeKeyID]
	if !ok {
		return nil, fmt.Errorf("active credential key id %q not found", m.activeKeyID)
	}
	// "|" separates the blob's fields, so Decrypt could not split an ID
	// containing it back out.
	if strings.Contains(m.activeKeyID, "|") {
		return nil, fmt.Errorf("credential key id %q must not contain %q", m.activeKeyID, "|")
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("new cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("new gcm: %w", err)
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, fmt.Errorf("nonce: %w", err)
	}
	rawPayload, err := json.Marshal(payload{Username: username, Password: password})
	if err != nil {
		return nil, fmt.Errorf("marshal payload: %w", err)
	}
	ciphertext := gcm.Seal(nil, nonce, rawPayload, nil)
	serialized := strings.Join([]string{
		strings.TrimSuffix(prefix, "|"),
		m.activeKeyID,
		base64.StdEncoding.EncodeToString(nonce),
		base64.StdEncoding.EncodeToString(ciphertext),
	}, "|")
	return []byte(serialized), nil
}
// Decrypt opens a blob produced by Encrypt and returns the stored username
// and password.
//
// The blob must have the form
//
//	UMC1|<keyID>|<base64 nonce>|<base64 ciphertext>
//
// and keyID must be one of the Manager's keys. An error is returned if the
// format is not recognized, the key ID is unknown, a field is not valid
// base64, the nonce has the wrong length, GCM authentication fails, or the
// decrypted payload lacks a username or password.
func (m *Manager) Decrypt(blob []byte) (string, string, error) {
	parts := strings.Split(string(blob), "|")
	if len(parts) != 4 || parts[0] != strings.TrimSuffix(prefix, "|") {
		return "", "", errors.New("credentials payload is not encrypted with supported format")
	}
	keyID := parts[1]
	key, ok := m.keys[keyID]
	if !ok {
		return "", "", fmt.Errorf("unknown credential key id %q", keyID)
	}
	nonce, err := base64.StdEncoding.DecodeString(parts[2])
	if err != nil {
		return "", "", fmt.Errorf("decode nonce: %w", err)
	}
	ciphertext, err := base64.StdEncoding.DecodeString(parts[3])
	if err != nil {
		return "", "", fmt.Errorf("decode ciphertext: %w", err)
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return "", "", fmt.Errorf("new cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", "", fmt.Errorf("new gcm: %w", err)
	}
	// gcm.Open panics when the nonce length differs from NonceSize. A
	// malformed blob must produce an error, not crash the caller.
	if len(nonce) != gcm.NonceSize() {
		return "", "", fmt.Errorf("invalid nonce length %d", len(nonce))
	}
	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return "", "", fmt.Errorf("decrypt: %w", err)
	}
	var p payload
	if err := json.Unmarshal(plaintext, &p); err != nil {
		return "", "", fmt.Errorf("unmarshal payload: %w", err)
	}
	if p.Username == "" || p.Password == "" {
		return "", "", errors.New("decrypted credentials are incomplete")
	}
	return p.Username, p.Password, nil
}
// IsEncrypted reports whether blob begins with the format prefix that
// Encrypt writes. Only the prefix is examined; the rest of the blob is not
// validated.
func IsEncrypted(blob []byte) bool {
	// prefix is non-empty, so a shorter remainder means it was stripped.
	rest := strings.TrimPrefix(string(blob), prefix)
	return len(rest) < len(blob)
}
// parseKeys parses a comma-separated list of "keyID:base64Key" entries into a
// map from key ID to raw key bytes.
//
// Whitespace around entries, IDs, and keys is ignored, and empty entries
// (e.g. a trailing comma) are skipped. An error is returned for:
//   - an entry without ':'
//   - an empty key ID
//   - a key ID containing '|', which is the field separator of the blob
//     format and would make the blob impossible to split back apart
//   - a duplicate key ID
//   - a key that is not valid standard base64
//   - a key whose decoded length is not 16, 24, or 32 bytes
//     (AES-128/192/256)
func parseKeys(spec string) (map[string][]byte, error) {
	entries := strings.Split(spec, ",")
	keys := make(map[string][]byte, len(entries))
	for _, entry := range entries {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		pair := strings.SplitN(entry, ":", 2)
		if len(pair) != 2 {
			return nil, fmt.Errorf("invalid key entry %q", entry)
		}
		keyID := strings.TrimSpace(pair[0])
		if keyID == "" {
			return nil, errors.New("key id cannot be empty")
		}
		if strings.Contains(keyID, "|") {
			return nil, fmt.Errorf("key id %q must not contain %q", keyID, "|")
		}
		// A silent overwrite would leave blobs sealed with the first key
		// failing authentication under the second.
		if _, dup := keys[keyID]; dup {
			return nil, fmt.Errorf("duplicate key id %q", keyID)
		}
		rawKey, err := base64.StdEncoding.DecodeString(strings.TrimSpace(pair[1]))
		if err != nil {
			return nil, fmt.Errorf("decode key %s: %w", keyID, err)
		}
		if l := len(rawKey); l != 16 && l != 24 && l != 32 {
			return nil, fmt.Errorf("key %s must be AES-128/192/256", keyID)
		}
		keys[keyID] = rawKey
	}
	return keys, nil
}