225 lines
6.0 KiB
Go
225 lines
6.0 KiB
Go
package oauth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
const (
	// pendingKeyPrefix namespaces the Redis keys that hold in-flight
	// PendingAccount records; the OAuth state token is appended to it.
	pendingKeyPrefix = "mail_oauth_pending:"

	// pendingTTL is how long a stored pending flow survives before Redis
	// expires it and the state becomes unknown.
	pendingTTL = 10 * time.Minute
)
|
|
|
|
// Sentinel errors returned by Service; compare with errors.Is.
var (
	// ErrUnknownState is returned when no pending flow is stored under the
	// given OAuth state, e.g. because it expired or was never issued.
	ErrUnknownState = errors.New("oauth state expired or unknown")

	// ErrProviderDisabled is returned when the requested provider is unknown
	// or its credentials are not configured.
	ErrProviderDisabled = errors.New("oauth provider not configured")
)
|
|
|
|
// Provider identifies an OAuth identity provider supported by Service.
type Provider string

// Supported providers. Their string values are what Start records in
// PendingAccount.Provider and what EnabledProviders reports.
const (
	ProviderGoogle    Provider = "google"
	ProviderMicrosoft Provider = "microsoft"
)
|
|
|
|
// PendingAccount is the record persisted as JSON in Redis between Start and
// Exchange. Callers fill in the mailbox details before calling Start; Start
// overwrites UserExternalID, Provider and PKCEVerifier.
type PendingAccount struct {
	// Stamped by Start.
	UserExternalID string `json:"user_external_id"`
	Provider       string `json:"provider"`

	// Mailbox identity and connection settings supplied by the caller and
	// carried through the flow unchanged.
	Email      string `json:"email"`
	Name       string `json:"name"`
	ProviderID string `json:"provider_id"`
	IMAPHost   string `json:"imap_host"`
	IMAPPort   int    `json:"imap_port"`
	IMAPTLS    bool   `json:"imap_tls"`
	SMTPHost   string `json:"smtp_host"`
	SMTPPort   int    `json:"smtp_port"`
	SMTPTLS    bool   `json:"smtp_tls"`

	// PKCEVerifier is generated by Start and sent to the token endpoint by
	// Exchange. It must stay server-side.
	PKCEVerifier string `json:"pkce_verifier"`
}
|
|
|
|
// Config carries OAuth client credentials. A provider is enabled only when its
// client ID and secret and the shared RedirectURL are all non-empty.
type Config struct {
	GoogleClientID     string
	GoogleClientSecret string

	MicrosoftClientID string
	MicrosoftSecret   string
	// MicrosoftTenant is substituted into the login.microsoftonline.com URLs;
	// empty means "common".
	MicrosoftTenant string

	// RedirectURL is used as the callback URL for every provider.
	RedirectURL string
}
|
|
|
|
type Service struct {
|
|
cfg Config
|
|
rdb *redis.Client
|
|
}
|
|
|
|
func NewService(cfg Config, rdb *redis.Client) *Service {
|
|
return &Service{cfg: cfg, rdb: rdb}
|
|
}
|
|
|
|
func (s *Service) EnabledProviders() []string {
|
|
var out []string
|
|
if s.providerConfig(ProviderGoogle) != nil {
|
|
out = append(out, string(ProviderGoogle))
|
|
}
|
|
if s.providerConfig(ProviderMicrosoft) != nil {
|
|
out = append(out, string(ProviderMicrosoft))
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (s *Service) Start(ctx context.Context, userExternalID string, provider Provider, pending PendingAccount) (authURL, state string, err error) {
|
|
oauthCfg := s.providerConfig(provider)
|
|
if oauthCfg == nil {
|
|
return "", "", ErrProviderDisabled
|
|
}
|
|
verifier, challenge, err := newPKCE()
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
state, err = randomToken(24)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
pending.UserExternalID = userExternalID
|
|
pending.Provider = string(provider)
|
|
pending.PKCEVerifier = verifier
|
|
|
|
if err := s.savePending(ctx, state, pending); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
authURL = oauthCfg.AuthCodeURL(state,
|
|
oauth2.AccessTypeOffline,
|
|
oauth2.SetAuthURLParam("code_challenge", challenge),
|
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
|
oauth2.SetAuthURLParam("prompt", "consent"),
|
|
)
|
|
return authURL, state, nil
|
|
}
|
|
|
|
func (s *Service) Exchange(ctx context.Context, state, code string) (PendingAccount, *oauth2.Token, error) {
|
|
pending, err := s.loadPending(ctx, state)
|
|
if err != nil {
|
|
return PendingAccount{}, nil, err
|
|
}
|
|
oauthCfg := s.providerConfig(Provider(pending.Provider))
|
|
if oauthCfg == nil {
|
|
return PendingAccount{}, nil, ErrProviderDisabled
|
|
}
|
|
token, err := oauthCfg.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", pending.PKCEVerifier))
|
|
if err != nil {
|
|
return PendingAccount{}, nil, fmt.Errorf("token exchange: %w", err)
|
|
}
|
|
_ = s.rdb.Del(ctx, pendingKeyPrefix+state).Err()
|
|
return pending, token, nil
|
|
}
|
|
|
|
func (s *Service) Refresh(ctx context.Context, provider, refreshToken string) (*oauth2.Token, error) {
|
|
oauthCfg := s.providerConfig(Provider(provider))
|
|
if oauthCfg == nil {
|
|
return nil, ErrProviderDisabled
|
|
}
|
|
src := oauthCfg.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken})
|
|
token, err := src.Token()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
func (s *Service) providerConfig(provider Provider) *oauth2.Config {
|
|
switch provider {
|
|
case ProviderGoogle:
|
|
if s.cfg.GoogleClientID == "" || s.cfg.GoogleClientSecret == "" || s.cfg.RedirectURL == "" {
|
|
return nil
|
|
}
|
|
return &oauth2.Config{
|
|
ClientID: s.cfg.GoogleClientID,
|
|
ClientSecret: s.cfg.GoogleClientSecret,
|
|
RedirectURL: s.cfg.RedirectURL,
|
|
Scopes: []string{"https://mail.google.com/"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: "https://accounts.google.com/o/oauth2/v2/auth",
|
|
TokenURL: "https://oauth2.googleapis.com/token",
|
|
},
|
|
}
|
|
case ProviderMicrosoft:
|
|
if s.cfg.MicrosoftClientID == "" || s.cfg.MicrosoftSecret == "" || s.cfg.RedirectURL == "" {
|
|
return nil
|
|
}
|
|
tenant := s.cfg.MicrosoftTenant
|
|
if tenant == "" {
|
|
tenant = "common"
|
|
}
|
|
return &oauth2.Config{
|
|
ClientID: s.cfg.MicrosoftClientID,
|
|
ClientSecret: s.cfg.MicrosoftSecret,
|
|
RedirectURL: s.cfg.RedirectURL,
|
|
Scopes: []string{
|
|
"offline_access",
|
|
"https://outlook.office.com/IMAP.AccessAsUser.All",
|
|
"https://outlook.office.com/SMTP.Send",
|
|
},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", tenant),
|
|
TokenURL: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", tenant),
|
|
},
|
|
}
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (s *Service) savePending(ctx context.Context, state string, pending PendingAccount) error {
|
|
if s.rdb == nil {
|
|
return errors.New("oauth state store unavailable")
|
|
}
|
|
raw, err := json.Marshal(pending)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.rdb.Set(ctx, pendingKeyPrefix+state, raw, pendingTTL).Err()
|
|
}
|
|
|
|
func (s *Service) loadPending(ctx context.Context, state string) (PendingAccount, error) {
|
|
if s.rdb == nil {
|
|
return PendingAccount{}, errors.New("oauth state store unavailable")
|
|
}
|
|
raw, err := s.rdb.Get(ctx, pendingKeyPrefix+state).Bytes()
|
|
if err != nil {
|
|
if errors.Is(err, redis.Nil) {
|
|
return PendingAccount{}, ErrUnknownState
|
|
}
|
|
return PendingAccount{}, err
|
|
}
|
|
var pending PendingAccount
|
|
if err := json.Unmarshal(raw, &pending); err != nil {
|
|
return PendingAccount{}, err
|
|
}
|
|
return pending, nil
|
|
}
|
|
|
|
func newPKCE() (verifier, challenge string, err error) {
|
|
b := make([]byte, 32)
|
|
if _, err := rand.Read(b); err != nil {
|
|
return "", "", err
|
|
}
|
|
verifier = base64URLEncode(b)
|
|
sum, err := sha256Sum(verifier)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
challenge = base64URLEncode(sum)
|
|
return verifier, challenge, nil
|
|
}
|
|
|
|
func randomToken(n int) (string, error) {
|
|
b := make([]byte, n)
|
|
if _, err := rand.Read(b); err != nil {
|
|
return "", err
|
|
}
|
|
return base64URLEncode(b), nil
|
|
}
|