- Added device token management API for mobile devices, including registration, unregistration, and listing of devices. - Implemented push notification functionality using FCM for Android and APNS for iOS. - Introduced new endpoints for device registration and management in the devices API. - Enhanced the configuration to support mobile push notifications with optional credentials for FCM and APNS. - Updated database schema to include a new table for storing device tokens. - Added integration tests for device management and push notification features.
205 lines
5.0 KiB
Go
205 lines
5.0 KiB
Go
package push
|
|
|
|
import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/rand"
	"crypto/sha256"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"io"
	"math/big"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"
)
|
|
|
|
// apnsClient delivers push notifications via the APNS HTTP/2 provider API,
// authenticating each request with a provider JWT signed by a token-based
// (.p8) key. Construct it with newAPNSClient.
type apnsClient struct {
	// Signing material for the provider JWT: the ECDSA key plus the key ID
	// ("kid" header) and team ID ("iss" claim).
	key           *ecdsa.PrivateKey
	keyID, teamID string

	// Routing: topic is sent as the apns-topic header, host is the provider
	// base URL, and http performs the requests.
	topic, host string
	http        *http.Client

	// mu guards the cached token and its issue time.
	mu        sync.Mutex
	cachedJWT string
	jwtIssued time.Time
}
|
|
|
|
// newAPNSClient returns a configured client, or nil (no error) when APNS is not
|
|
// configured. An error is returned only when provided credentials are invalid.
|
|
func newAPNSClient(cfg Config) (*apnsClient, error) {
|
|
pemKey := strings.TrimSpace(cfg.APNSPrivateKey)
|
|
if pemKey == "" {
|
|
return nil, nil
|
|
}
|
|
if strings.TrimSpace(cfg.APNSKeyID) == "" ||
|
|
strings.TrimSpace(cfg.APNSTeamID) == "" ||
|
|
strings.TrimSpace(cfg.APNSBundleID) == "" {
|
|
return nil, fmt.Errorf("apns requires APNS_KEY_ID, APNS_TEAM_ID and APNS_BUNDLE_ID")
|
|
}
|
|
|
|
key, err := parseAPNSKey(pemKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
host := "https://api.sandbox.push.apple.com"
|
|
if cfg.APNSProduction {
|
|
host = "https://api.push.apple.com"
|
|
}
|
|
|
|
return &apnsClient{
|
|
key: key,
|
|
keyID: strings.TrimSpace(cfg.APNSKeyID),
|
|
teamID: strings.TrimSpace(cfg.APNSTeamID),
|
|
topic: strings.TrimSpace(cfg.APNSBundleID),
|
|
host: host,
|
|
http: &http.Client{Timeout: 10 * time.Second},
|
|
}, nil
|
|
}
|
|
|
|
// parseAPNSKey decodes an APNS auth key (.p8 contents): a PEM block holding a
// PKCS#8 ECDSA private key on the P-256 curve, the only curve ES256 permits.
//
// Keys supplied through environment variables often have their line breaks
// written as the two-character sequence `\n`. If the input does not decode
// as PEM on its own, those sequences are expanded to real newlines and
// decoding is retried once. Input that already decodes is never rewritten.
//
// An error is returned when the input is not PEM, is not PKCS#8, is not an
// ECDSA key, or uses a curve other than P-256.
func parseAPNSKey(pemKey string) (*ecdsa.PrivateKey, error) {
	block, _ := pem.Decode([]byte(pemKey))
	if block == nil && strings.Contains(pemKey, `\n`) {
		block, _ = pem.Decode([]byte(strings.ReplaceAll(pemKey, `\n`, "\n")))
	}
	if block == nil {
		return nil, fmt.Errorf("apns private key is not valid PEM")
	}

	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("parse apns private key: %w", err)
	}
	key, ok := parsed.(*ecdsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("apns private key is not an ECDSA key")
	}

	// Reject other curves now rather than at signing time. A larger curve
	// (e.g. P-384) produces R/S values that overflow the fixed 32-byte
	// R||S encoding used for the JWT signature.
	if name := key.Curve.Params().Name; name != "P-256" {
		return nil, fmt.Errorf("apns private key must use curve P-256, got %s", name)
	}
	return key, nil
}
|
|
|
|
// jwt returns a cached provider authentication token, regenerating it when it
|
|
// is older than 40 minutes (Apple allows reuse for up to 60 minutes).
|
|
func (c *apnsClient) jwt() (string, error) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.cachedJWT != "" && time.Since(c.jwtIssued) < 40*time.Minute {
|
|
return c.cachedJWT, nil
|
|
}
|
|
|
|
now := time.Now()
|
|
header := map[string]string{"alg": "ES256", "kid": c.keyID}
|
|
claims := map[string]any{"iss": c.teamID, "iat": now.Unix()}
|
|
|
|
headerJSON, _ := json.Marshal(header)
|
|
claimsJSON, _ := json.Marshal(claims)
|
|
|
|
signingInput := base64URL(headerJSON) + "." + base64URL(claimsJSON)
|
|
|
|
digest := sha256.Sum256([]byte(signingInput))
|
|
r, s, err := ecdsa.Sign(rand.Reader, c.key, digest[:])
|
|
if err != nil {
|
|
return "", fmt.Errorf("sign apns jwt: %w", err)
|
|
}
|
|
signature := ecdsaSignatureBytes(r, s)
|
|
|
|
token := signingInput + "." + base64URL(signature)
|
|
c.cachedJWT = token
|
|
c.jwtIssued = now
|
|
return token, nil
|
|
}
|
|
|
|
// JSON request body types for APNS. Only the fields this package populates
// are modelled.
type (
	// apnsPayload is the JSON document POSTed to /3/device/{token}. Custom
	// key/value data travels under the "data" key and is omitted when empty.
	apnsPayload struct {
		APS  apnsAPS           `json:"aps"`
		Data map[string]string `json:"data,omitempty"`
	}

	// apnsAPS is the "aps" dictionary: the alert text plus an optional sound.
	apnsAPS struct {
		Alert apnsAlert `json:"alert"`
		Sound string    `json:"sound,omitempty"`
	}

	// apnsAlert holds the visible title and body; empty strings are omitted.
	apnsAlert struct {
		Title string `json:"title,omitempty"`
		Body  string `json:"body,omitempty"`
	}
)
|
|
|
|
func (c *apnsClient) send(ctx context.Context, deviceToken string, n Notification) error {
|
|
token, err := c.jwt()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
payload := apnsPayload{
|
|
APS: apnsAPS{
|
|
Alert: apnsAlert{Title: n.Title, Body: n.Body},
|
|
Sound: "default",
|
|
},
|
|
Data: n.Data,
|
|
}
|
|
body, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
url := c.host + "/3/device/" + deviceToken
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(body)))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.Header.Set("authorization", "bearer "+token)
|
|
req.Header.Set("apns-topic", c.topic)
|
|
req.Header.Set("apns-push-type", "alert")
|
|
req.Header.Set("content-type", "application/json")
|
|
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode == http.StatusOK {
|
|
return nil
|
|
}
|
|
|
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
|
// 410 Gone (and 400 BadDeviceToken/DeviceTokenNotForTopic) means prune.
|
|
if resp.StatusCode == http.StatusGone || apnsIsBadToken(respBody) {
|
|
return errTokenUnregistered
|
|
}
|
|
return fmt.Errorf("apns send status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
|
}
|
|
|
|
// apnsIsBadToken reports whether an APNS error response body is JSON whose
// "reason" field is one of Unregistered, BadDeviceToken or
// DeviceTokenNotForTopic. Bodies that are not valid JSON yield false.
func apnsIsBadToken(body []byte) bool {
	var errResp struct {
		Reason string `json:"reason"`
	}
	if json.Unmarshal(body, &errResp) != nil {
		return false
	}
	reason := errResp.Reason
	return reason == "Unregistered" ||
		reason == "BadDeviceToken" ||
		reason == "DeviceTokenNotForTopic"
}
|
|
|
|
// base64URL encodes b with the URL-safe base64 alphabet and no '=' padding,
// the form used for each dot-separated segment of a JWT.
func base64URL(b []byte) string {
	enc := base64.RawURLEncoding
	return enc.EncodeToString(b)
}
|
|
|
|
// ecdsaSignatureBytes serialises an ES256 signature in the JWS form: R and S
// as big-endian integers, each left-padded with zeros to 32 bytes, then
// concatenated (64 bytes total). Like big.Int.FillBytes, it panics if either
// value needs more than 32 bytes.
func ecdsaSignatureBytes(r, s *big.Int) []byte {
	const width = 32
	sig := make([]byte, 2*width)
	for i, v := range [...]*big.Int{r, s} {
		v.FillBytes(sig[i*width : (i+1)*width])
	}
	return sig
}
|