231 lines
5.9 KiB
Go
231 lines
5.9 KiB
Go
package mail
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
|
|
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
|
)
|
|
|
|
// accountRow holds the mail_accounts columns selected by loadAccountRow.
//
// Field order is intentionally unchanged so that any positional composite
// literal elsewhere in the package keeps compiling.
type accountRow struct {
	// Identity and display metadata.
	id       string
	name     string
	email    string
	provider string

	// Server endpoints: hosts, then ports, then TLS flags.
	imapHost string
	smtpHost string
	imapPort int
	smtpPort int
	imapTLS  bool
	smtpTLS  bool

	// credentials is the stored credential blob. UpdateAccount writes it from
	// EncryptCredential, and decryptAccountCredential reads it back. An empty
	// blob is treated there as "no credential".
	credentials []byte
}
|
|
|
|
func (s *Service) loadAccountRow(ctx context.Context, externalID, accountID string) (accountRow, error) {
|
|
var row accountRow
|
|
err := s.db.QueryRow(ctx, `
|
|
SELECT ma.id, ma.name, ma.email, ma.provider,
|
|
ma.imap_host, ma.imap_port, ma.imap_tls,
|
|
ma.smtp_host, ma.smtp_port, ma.smtp_tls,
|
|
ma.credentials
|
|
FROM mail_accounts ma
|
|
JOIN users u ON ma.user_id = u.id
|
|
WHERE ma.id = $1 AND u.external_id = $2
|
|
`, accountID, externalID).Scan(
|
|
&row.id, &row.name, &row.email, &row.provider,
|
|
&row.imapHost, &row.imapPort, &row.imapTLS,
|
|
&row.smtpHost, &row.smtpPort, &row.smtpTLS,
|
|
&row.credentials,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return accountRow{}, ErrNotFound
|
|
}
|
|
return accountRow{}, err
|
|
}
|
|
return row, nil
|
|
}
|
|
|
|
func accountDetailFromRow(row accountRow, cred credentials.Credential) map[string]any {
|
|
out := map[string]any{
|
|
"id": row.id,
|
|
"name": row.name,
|
|
"email": row.email,
|
|
"provider": row.provider,
|
|
"imap_host": row.imapHost,
|
|
"imap_port": row.imapPort,
|
|
"imap_tls": row.imapTLS,
|
|
"smtp_host": row.smtpHost,
|
|
"smtp_port": row.smtpPort,
|
|
"smtp_tls": row.smtpTLS,
|
|
"auth_type": string(credentials.AuthPassword),
|
|
"username": cred.Username,
|
|
}
|
|
if cred.AuthType != "" {
|
|
out["auth_type"] = string(cred.AuthType)
|
|
}
|
|
if cred.IsOAuth() {
|
|
out["oauth_provider"] = cred.OAuthProvider
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (s *Service) decryptAccountCredential(blob []byte) (credentials.Credential, error) {
|
|
if s.credentials == nil {
|
|
return credentials.Credential{}, ErrCredentialsUnavailable
|
|
}
|
|
if len(blob) == 0 {
|
|
return credentials.Credential{}, nil
|
|
}
|
|
return s.credentials.DecryptCredential(blob)
|
|
}
|
|
|
|
func (s *Service) GetAccount(ctx context.Context, externalID, accountID string) (map[string]any, error) {
|
|
row, err := s.loadAccountRow(ctx, externalID, accountID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cred, err := s.decryptAccountCredential(row.credentials)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return accountDetailFromRow(row, cred), nil
|
|
}
|
|
|
|
func (s *Service) UpdateAccount(ctx context.Context, externalID, accountID string, req *updateAccountRequest) error {
|
|
if s.credentials == nil {
|
|
return ErrCredentialsUnavailable
|
|
}
|
|
|
|
row, err := s.loadAccountRow(ctx, externalID, accountID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
cred, err := s.decryptAccountCredential(row.credentials)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
name := strings.TrimSpace(req.Name)
|
|
if name == "" {
|
|
name = row.name
|
|
}
|
|
provider := strings.TrimSpace(req.Provider)
|
|
if provider == "" {
|
|
provider = row.provider
|
|
}
|
|
|
|
if cred.IsOAuth() {
|
|
if strings.TrimSpace(req.Password) != "" {
|
|
return ErrOAuthPasswordNotAllowed
|
|
}
|
|
if u := strings.TrimSpace(req.Username); u != "" {
|
|
cred.Username = u
|
|
}
|
|
} else {
|
|
if u := strings.TrimSpace(req.Username); u != "" {
|
|
cred.Username = u
|
|
}
|
|
if p := req.Password; p != "" {
|
|
cred.Password = p
|
|
}
|
|
if strings.TrimSpace(cred.Username) == "" {
|
|
return ErrInvalidAccountCredentials
|
|
}
|
|
}
|
|
|
|
encrypted, err := s.credentials.EncryptCredential(cred)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
result, err := s.db.Exec(ctx, `
|
|
UPDATE mail_accounts ma SET
|
|
name = $1,
|
|
email = $2,
|
|
provider = $3,
|
|
imap_host = $4,
|
|
imap_port = $5,
|
|
imap_tls = $6,
|
|
smtp_host = $7,
|
|
smtp_port = $8,
|
|
smtp_tls = $9,
|
|
credentials = $10,
|
|
updated_at = NOW()
|
|
FROM users u
|
|
WHERE ma.id = $11 AND ma.user_id = u.id AND u.external_id = $12
|
|
`, name, req.Email, provider,
|
|
req.IMAPHost, req.IMAPPort, req.IMAPTLS,
|
|
req.SMTPHost, req.SMTPPort, req.SMTPTLS,
|
|
encrypted, accountID, externalID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if result.RowsAffected() == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func credentialFromTestRequest(req *testAccountRequest) (credentials.Credential, error) {
|
|
if req.AuthType == string(credentials.AuthOAuth2) {
|
|
return credentials.Credential{
|
|
AuthType: credentials.AuthOAuth2,
|
|
Username: strings.TrimSpace(req.Username),
|
|
AccessToken: strings.TrimSpace(req.AccessToken),
|
|
OAuthProvider: strings.TrimSpace(req.OAuthProvider),
|
|
}, nil
|
|
}
|
|
if strings.TrimSpace(req.Password) == "" {
|
|
return credentials.Credential{}, ErrInvalidAccountCredentials
|
|
}
|
|
return credentials.Credential{
|
|
AuthType: credentials.AuthPassword,
|
|
Username: strings.TrimSpace(req.Username),
|
|
Password: req.Password,
|
|
}, nil
|
|
}
|
|
|
|
func mergeTestCredential(stored credentials.Credential, req *testAccountRequest) (credentials.Credential, error) {
|
|
cred := stored
|
|
if u := strings.TrimSpace(req.Username); u != "" {
|
|
cred.Username = u
|
|
}
|
|
if req.AuthType == string(credentials.AuthOAuth2) || cred.IsOAuth() {
|
|
if t := strings.TrimSpace(req.AccessToken); t != "" {
|
|
cred.AccessToken = t
|
|
}
|
|
if strings.TrimSpace(cred.AccessToken) == "" {
|
|
return credentials.Credential{}, ErrInvalidAccountCredentials
|
|
}
|
|
return cred, nil
|
|
}
|
|
if req.Password != "" {
|
|
cred.Password = req.Password
|
|
}
|
|
if strings.TrimSpace(cred.Password) == "" {
|
|
return credentials.Credential{}, ErrInvalidAccountCredentials
|
|
}
|
|
if strings.TrimSpace(cred.Username) == "" {
|
|
return credentials.Credential{}, ErrInvalidAccountCredentials
|
|
}
|
|
return cred, nil
|
|
}
|
|
|
|
func (s *Service) CredentialForConnectionTest(ctx context.Context, externalID string, req *testAccountRequest) (credentials.Credential, error) {
|
|
accountID := strings.TrimSpace(req.AccountID)
|
|
if accountID == "" {
|
|
return credentialFromTestRequest(req)
|
|
}
|
|
|
|
row, err := s.loadAccountRow(ctx, externalID, accountID)
|
|
if err != nil {
|
|
return credentials.Credential{}, err
|
|
}
|
|
|
|
stored, err := s.decryptAccountCredential(row.credentials)
|
|
if err != nil {
|
|
return credentials.Credential{}, err
|
|
}
|
|
|
|
return mergeTestCredential(stored, req)
|
|
}
|