package oauth import ( "context" "crypto/rand" "encoding/json" "errors" "fmt" "time" "github.com/redis/go-redis/v9" "golang.org/x/oauth2" ) const pendingKeyPrefix = "mail_oauth_pending:" const pendingTTL = 10 * time.Minute var ErrUnknownState = errors.New("oauth state expired or unknown") var ErrProviderDisabled = errors.New("oauth provider not configured") type Provider string const ( ProviderGoogle Provider = "google" ProviderMicrosoft Provider = "microsoft" ) type PendingAccount struct { UserExternalID string `json:"user_external_id"` Provider string `json:"provider"` Email string `json:"email"` Name string `json:"name"` ProviderID string `json:"provider_id"` IMAPHost string `json:"imap_host"` IMAPPort int `json:"imap_port"` IMAPTLS bool `json:"imap_tls"` SMTPHost string `json:"smtp_host"` SMTPPort int `json:"smtp_port"` SMTPTLS bool `json:"smtp_tls"` PKCEVerifier string `json:"pkce_verifier"` } type Config struct { GoogleClientID string GoogleClientSecret string MicrosoftClientID string MicrosoftSecret string MicrosoftTenant string RedirectURL string } type Service struct { cfg Config rdb *redis.Client } func NewService(cfg Config, rdb *redis.Client) *Service { return &Service{cfg: cfg, rdb: rdb} } func (s *Service) EnabledProviders() []string { var out []string if s.providerConfig(ProviderGoogle) != nil { out = append(out, string(ProviderGoogle)) } if s.providerConfig(ProviderMicrosoft) != nil { out = append(out, string(ProviderMicrosoft)) } return out } func (s *Service) Start(ctx context.Context, userExternalID string, provider Provider, pending PendingAccount) (authURL, state string, err error) { oauthCfg := s.providerConfig(provider) if oauthCfg == nil { return "", "", ErrProviderDisabled } verifier, challenge, err := newPKCE() if err != nil { return "", "", err } state, err = randomToken(24) if err != nil { return "", "", err } pending.UserExternalID = userExternalID pending.Provider = string(provider) pending.PKCEVerifier = verifier if err := s.savePending(ctx, state, pending); err != nil { return "", "", err } authURL = oauthCfg.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("code_challenge", challenge), oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("prompt", "consent"), ) return authURL, state, nil } func (s *Service) Exchange(ctx context.Context, state, code string) (PendingAccount, *oauth2.Token, error) { pending, err := s.loadPending(ctx, state) if err != nil { return PendingAccount{}, nil, err } oauthCfg := s.providerConfig(Provider(pending.Provider)) if oauthCfg == nil { return PendingAccount{}, nil, ErrProviderDisabled } token, err := oauthCfg.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", pending.PKCEVerifier)) if err != nil { return PendingAccount{}, nil, fmt.Errorf("token exchange: %w", err) } _ = s.rdb.Del(ctx, pendingKeyPrefix+state).Err() return pending, token, nil } func (s *Service) Refresh(ctx context.Context, provider, refreshToken string) (*oauth2.Token, error) { oauthCfg := s.providerConfig(Provider(provider)) if oauthCfg == nil { return nil, ErrProviderDisabled } src := oauthCfg.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken}) token, err := src.Token() if err != nil { return nil, err } return token, nil } func (s *Service) providerConfig(provider Provider) *oauth2.Config { switch provider { case ProviderGoogle: if s.cfg.GoogleClientID == "" || s.cfg.GoogleClientSecret == "" || s.cfg.RedirectURL == "" { return nil } return &oauth2.Config{ ClientID: s.cfg.GoogleClientID, ClientSecret: s.cfg.GoogleClientSecret, RedirectURL: s.cfg.RedirectURL, Scopes: []string{"https://mail.google.com/"}, Endpoint: oauth2.Endpoint{ AuthURL: "https://accounts.google.com/o/oauth2/v2/auth", TokenURL: "https://oauth2.googleapis.com/token", }, } case ProviderMicrosoft: if s.cfg.MicrosoftClientID == "" || s.cfg.MicrosoftSecret == "" || s.cfg.RedirectURL == "" { return nil } tenant := s.cfg.MicrosoftTenant if tenant == "" { tenant = "common" } return &oauth2.Config{ ClientID: s.cfg.MicrosoftClientID, ClientSecret: s.cfg.MicrosoftSecret, RedirectURL: s.cfg.RedirectURL, Scopes: []string{ "offline_access", "https://outlook.office.com/IMAP.AccessAsUser.All", "https://outlook.office.com/SMTP.Send", }, Endpoint: oauth2.Endpoint{ AuthURL: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", tenant), TokenURL: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", tenant), }, } default: return nil } } func (s *Service) savePending(ctx context.Context, state string, pending PendingAccount) error { if s.rdb == nil { return errors.New("oauth state store unavailable") } raw, err := json.Marshal(pending) if err != nil { return err } return s.rdb.Set(ctx, pendingKeyPrefix+state, raw, pendingTTL).Err() } func (s *Service) loadPending(ctx context.Context, state string) (PendingAccount, error) { if s.rdb == nil { return PendingAccount{}, errors.New("oauth state store unavailable") } raw, err := s.rdb.Get(ctx, pendingKeyPrefix+state).Bytes() if err != nil { if errors.Is(err, redis.Nil) { return PendingAccount{}, ErrUnknownState } return PendingAccount{}, err } var pending PendingAccount if err := json.Unmarshal(raw, &pending); err != nil { return PendingAccount{}, err } return pending, nil } func newPKCE() (verifier, challenge string, err error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", "", err } verifier = base64URLEncode(b) sum, err := sha256Sum(verifier) if err != nil { return "", "", err } challenge = base64URLEncode(sum) return verifier, challenge, nil } func randomToken(n int) (string, error) { b := make([]byte, n) if _, err := rand.Read(b); err != nil { return "", err } return base64URLEncode(b), nil }