161 lines
3.8 KiB
Go
161 lines
3.8 KiB
Go
package credentials
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
)
|
|
|
|
// prefix is the leading marker of every blob produced by Manager.Encrypt: the
// "UMC1" format tag followed by the "|" field separator. IsEncrypted and
// Decrypt both key off it.
const (
	prefix = "UMC1|"
)
|
|
|
|
// Manager seals and opens mail credentials with AES-GCM. It holds a set of
// keys indexed by key ID: new blobs are always written with the active key,
// while Decrypt accepts a blob written under any key still present in keys.
// Build it with NewManager.
type Manager struct {
	// activeKeyID names the entry in keys that Encrypt uses.
	activeKeyID string
	// keys maps a key ID to raw AES key bytes. When built through NewManager,
	// parseKeys has validated each key to be 16, 24 or 32 bytes long.
	keys map[string][]byte
}
|
|
|
|
// payload is the plaintext sealed inside an encrypted credential blob. It is
// serialized as JSON before encryption and decoded again after decryption.
// Field order determines the order of keys in the marshalled JSON.
type payload struct {
	// Username is stored under the "username" JSON key.
	Username string `json:"username"`
	// Password is stored under the "password" JSON key.
	Password string `json:"password"`
}
|
|
|
|
func NewManager(keysSpec, activeKeyID string) (*Manager, error) {
|
|
keys, err := parseKeys(keysSpec)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(keys) == 0 {
|
|
return nil, errors.New("mail credential keys are required")
|
|
}
|
|
|
|
if activeKeyID == "" {
|
|
for keyID := range keys {
|
|
activeKeyID = keyID
|
|
break
|
|
}
|
|
}
|
|
if _, ok := keys[activeKeyID]; !ok {
|
|
return nil, fmt.Errorf("active credential key id %q not found", activeKeyID)
|
|
}
|
|
|
|
return &Manager{
|
|
activeKeyID: activeKeyID,
|
|
keys: keys,
|
|
}, nil
|
|
}
|
|
|
|
func (m *Manager) Encrypt(username, password string) ([]byte, error) {
|
|
key := m.keys[m.activeKeyID]
|
|
block, err := aes.NewCipher(key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("new cipher: %w", err)
|
|
}
|
|
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("new gcm: %w", err)
|
|
}
|
|
|
|
nonce := make([]byte, gcm.NonceSize())
|
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
return nil, fmt.Errorf("nonce: %w", err)
|
|
}
|
|
|
|
rawPayload, err := json.Marshal(payload{Username: username, Password: password})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal payload: %w", err)
|
|
}
|
|
|
|
ciphertext := gcm.Seal(nil, nonce, rawPayload, nil)
|
|
|
|
serialized := strings.Join([]string{
|
|
prefix[:len(prefix)-1],
|
|
m.activeKeyID,
|
|
base64.StdEncoding.EncodeToString(nonce),
|
|
base64.StdEncoding.EncodeToString(ciphertext),
|
|
}, "|")
|
|
return []byte(serialized), nil
|
|
}
|
|
|
|
func (m *Manager) Decrypt(blob []byte) (string, string, error) {
|
|
parts := strings.Split(string(blob), "|")
|
|
if len(parts) != 4 || parts[0] != strings.TrimSuffix(prefix, "|") {
|
|
return "", "", errors.New("credentials payload is not encrypted with supported format")
|
|
}
|
|
|
|
keyID := parts[1]
|
|
key, ok := m.keys[keyID]
|
|
if !ok {
|
|
return "", "", fmt.Errorf("unknown credential key id %q", keyID)
|
|
}
|
|
|
|
nonce, err := base64.StdEncoding.DecodeString(parts[2])
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("decode nonce: %w", err)
|
|
}
|
|
ciphertext, err := base64.StdEncoding.DecodeString(parts[3])
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("decode ciphertext: %w", err)
|
|
}
|
|
|
|
block, err := aes.NewCipher(key)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("new cipher: %w", err)
|
|
}
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("new gcm: %w", err)
|
|
}
|
|
|
|
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("decrypt: %w", err)
|
|
}
|
|
|
|
var p payload
|
|
if err := json.Unmarshal(plaintext, &p); err != nil {
|
|
return "", "", fmt.Errorf("unmarshal payload: %w", err)
|
|
}
|
|
if p.Username == "" || p.Password == "" {
|
|
return "", "", errors.New("decrypted credentials are incomplete")
|
|
}
|
|
return p.Username, p.Password, nil
|
|
}
|
|
|
|
func IsEncrypted(blob []byte) bool {
|
|
return strings.HasPrefix(string(blob), prefix)
|
|
}
|
|
|
|
// parseKeys parses a comma-separated list of "id:base64key" entries into a
// map from key ID to raw AES key bytes.
//
// Parsing rules:
//   - Whitespace around entries, IDs and keys is ignored.
//   - Empty entries are skipped, so an empty spec yields an empty map and no
//     error.
//   - The ID ends at the first ":".
//   - Each key must decode (standard base64) to 16, 24 or 32 bytes, i.e.
//     AES-128/192/256.
//   - Key IDs must be non-empty and unique, and must not contain "|".
//
// Error messages never echo key material, so they are safe to log.
func parseKeys(spec string) (map[string][]byte, error) {
	entries := strings.Split(spec, ",")
	keys := make(map[string][]byte, len(entries))
	for i, entry := range entries {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		pair := strings.SplitN(entry, ":", 2)
		if len(pair) != 2 {
			// Report the position, not the entry itself: an entry missing its
			// "id:" part is most likely a bare secret key.
			return nil, fmt.Errorf("invalid key entry #%d: expected \"id:base64key\"", i+1)
		}

		keyID := strings.TrimSpace(pair[0])
		if keyID == "" {
			return nil, errors.New("key id cannot be empty")
		}
		// Encrypt joins blob fields with "|" and Decrypt splits on it, so an ID
		// containing "|" would produce blobs that can never be decrypted.
		if strings.Contains(keyID, "|") {
			return nil, fmt.Errorf("key id %q must not contain %q", keyID, "|")
		}
		// A silent overwrite would make blobs sealed with the earlier key
		// undecryptable.
		if _, dup := keys[keyID]; dup {
			return nil, fmt.Errorf("duplicate key id %q", keyID)
		}

		rawKey, err := base64.StdEncoding.DecodeString(strings.TrimSpace(pair[1]))
		if err != nil {
			return nil, fmt.Errorf("decode key %q: %w", keyID, err)
		}
		if l := len(rawKey); l != 16 && l != 24 && l != 32 {
			return nil, fmt.Errorf("key %q must be 16, 24 or 32 bytes for AES-128/192/256, got %d", keyID, l)
		}
		keys[keyID] = rawKey
	}
	return keys, nil
}
|