package push import ( "context" "crypto/ecdsa" "crypto/rand" "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "io" "math/big" "net/http" "strings" "sync" "time" ) // apnsClient sends notifications through the APNS HTTP/2 provider API using // token-based (.p8) authentication. type apnsClient struct { key *ecdsa.PrivateKey keyID string teamID string topic string host string http *http.Client mu sync.Mutex cachedJWT string jwtIssued time.Time } // newAPNSClient returns a configured client, or nil (no error) when APNS is not // configured. An error is returned only when provided credentials are invalid. func newAPNSClient(cfg Config) (*apnsClient, error) { pemKey := strings.TrimSpace(cfg.APNSPrivateKey) if pemKey == "" { return nil, nil } if strings.TrimSpace(cfg.APNSKeyID) == "" || strings.TrimSpace(cfg.APNSTeamID) == "" || strings.TrimSpace(cfg.APNSBundleID) == "" { return nil, fmt.Errorf("apns requires APNS_KEY_ID, APNS_TEAM_ID and APNS_BUNDLE_ID") } key, err := parseAPNSKey(pemKey) if err != nil { return nil, err } host := "https://api.sandbox.push.apple.com" if cfg.APNSProduction { host = "https://api.push.apple.com" } return &apnsClient{ key: key, keyID: strings.TrimSpace(cfg.APNSKeyID), teamID: strings.TrimSpace(cfg.APNSTeamID), topic: strings.TrimSpace(cfg.APNSBundleID), host: host, http: &http.Client{Timeout: 10 * time.Second}, }, nil } func parseAPNSKey(pemKey string) (*ecdsa.PrivateKey, error) { block, _ := pem.Decode([]byte(pemKey)) if block == nil { return nil, fmt.Errorf("apns private key is not valid PEM") } parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { return nil, fmt.Errorf("parse apns private key: %w", err) } key, ok := parsed.(*ecdsa.PrivateKey) if !ok { return nil, fmt.Errorf("apns private key is not an ECDSA key") } return key, nil } // jwt returns a cached provider authentication token, regenerating it when it // is older than 40 minutes (Apple allows reuse for up to 60 minutes). func (c *apnsClient) jwt() (string, error) { c.mu.Lock() defer c.mu.Unlock() if c.cachedJWT != "" && time.Since(c.jwtIssued) < 40*time.Minute { return c.cachedJWT, nil } now := time.Now() header := map[string]string{"alg": "ES256", "kid": c.keyID} claims := map[string]any{"iss": c.teamID, "iat": now.Unix()} headerJSON, _ := json.Marshal(header) claimsJSON, _ := json.Marshal(claims) signingInput := base64URL(headerJSON) + "." + base64URL(claimsJSON) digest := sha256.Sum256([]byte(signingInput)) r, s, err := ecdsa.Sign(rand.Reader, c.key, digest[:]) if err != nil { return "", fmt.Errorf("sign apns jwt: %w", err) } signature := ecdsaSignatureBytes(r, s) token := signingInput + "." + base64URL(signature) c.cachedJWT = token c.jwtIssued = now return token, nil } type apnsPayload struct { APS apnsAPS `json:"aps"` Data map[string]string `json:"data,omitempty"` } type apnsAPS struct { Alert apnsAlert `json:"alert"` Sound string `json:"sound,omitempty"` } type apnsAlert struct { Title string `json:"title,omitempty"` Body string `json:"body,omitempty"` } func (c *apnsClient) send(ctx context.Context, deviceToken string, n Notification) error { token, err := c.jwt() if err != nil { return err } payload := apnsPayload{ APS: apnsAPS{ Alert: apnsAlert{Title: n.Title, Body: n.Body}, Sound: "default", }, Data: n.Data, } body, err := json.Marshal(payload) if err != nil { return err } url := c.host + "/3/device/" + deviceToken req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(body))) if err != nil { return err } req.Header.Set("authorization", "bearer "+token) req.Header.Set("apns-topic", c.topic) req.Header.Set("apns-push-type", "alert") req.Header.Set("content-type", "application/json") resp, err := c.http.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { return nil } respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) // 410 Gone (and 400 BadDeviceToken/DeviceTokenNotForTopic) means prune. if resp.StatusCode == http.StatusGone || apnsIsBadToken(respBody) { return errTokenUnregistered } return fmt.Errorf("apns send status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) } func apnsIsBadToken(body []byte) bool { var parsed struct { Reason string `json:"reason"` } if err := json.Unmarshal(body, &parsed); err != nil { return false } switch parsed.Reason { case "Unregistered", "BadDeviceToken", "DeviceTokenNotForTopic": return true } return false } func base64URL(b []byte) string { return base64.RawURLEncoding.EncodeToString(b) } // ecdsaSignatureBytes encodes an ES256 signature as the fixed-width R||S form // required by JWS (each integer left-padded to 32 bytes). func ecdsaSignatureBytes(r, s *big.Int) []byte { const size = 32 out := make([]byte, size*2) r.FillBytes(out[:size]) s.FillBytes(out[size:]) return out }