package admin import ( "context" "errors" "strconv" "strings" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/ultisuite/ulti-backend/internal/api/query" ) var ErrGroupNameTaken = errors.New("group name already exists") type UserGroupsList struct { Groups []map[string]any `json:"groups"` Pagination query.PaginationMeta `json:"pagination,omitempty"` } func (s *Service) ListUserGroups(ctx context.Context, params query.ListParams) (UserGroupsList, error) { q := strings.TrimSpace(params.Q) whereSQL := "" args := make([]any, 0, 3) if q != "" { pattern := "%" + strings.ToLower(q) + "%" args = append(args, pattern) whereSQL = " WHERE LOWER(name) LIKE $1 OR LOWER(description) LIKE $1" } var total int64 if err := s.db.QueryRow(ctx, "SELECT COUNT(*) FROM user_groups"+whereSQL, args...).Scan(&total); err != nil { return UserGroupsList{}, err } listSQL := ` SELECT g.id, g.name, g.description, g.created_at, g.updated_at, COALESCE(m.member_count, 0) AS member_count FROM user_groups g LEFT JOIN ( SELECT group_id, COUNT(*) AS member_count FROM user_group_members GROUP BY group_id ) m ON m.group_id = g.id` + whereSQL + ` ORDER BY LOWER(g.name) LIMIT $` + strconv.Itoa(len(args)+1) + ` OFFSET $` + strconv.Itoa(len(args)+2) args = append(args, params.Limit(), params.Offset()) rows, err := s.db.Query(ctx, listSQL, args...) if err != nil { return UserGroupsList{}, err } defer rows.Close() groups := make([]map[string]any, 0) for rows.Next() { group, err := scanUserGroupRow(rows) if err != nil { return UserGroupsList{}, err } groups = append(groups, group) } if err := rows.Err(); err != nil { return UserGroupsList{}, err } return UserGroupsList{ Groups: groups, Pagination: params.Meta(&total), }, nil } func (s *Service) GetUserGroup(ctx context.Context, groupID string) (map[string]any, error) { row := s.db.QueryRow(ctx, ` SELECT g.id, g.name, g.description, g.created_at, g.updated_at, COALESCE(m.member_count, 0) AS member_count FROM user_groups g LEFT JOIN ( SELECT group_id, COUNT(*) AS member_count FROM user_group_members GROUP BY group_id ) m ON m.group_id = g.id WHERE g.id = $1 `, groupID) group, err := scanUserGroupRow(row) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, ErrNotFound } return nil, err } return group, nil } type createUserGroupRequest struct { Name string `json:"name"` Description string `json:"description"` } type updateUserGroupRequest struct { Name *string `json:"name"` Description *string `json:"description"` } func (s *Service) CreateUserGroup(ctx context.Context, actorSub string, req createUserGroupRequest) (map[string]any, error) { name := strings.TrimSpace(req.Name) description := strings.TrimSpace(req.Description) var id string err := s.db.QueryRow(ctx, ` INSERT INTO user_groups (name, description) VALUES ($1, $2) RETURNING id `, name, description).Scan(&id) if err != nil { if isUniqueViolation(err) { return nil, ErrGroupNameTaken } return nil, err } s.logAudit(ctx, actorSub, "create_user_group", map[string]any{ "group_id": id, "name": name, }) return s.GetUserGroup(ctx, id) } func (s *Service) UpdateUserGroup(ctx context.Context, actorSub, groupID string, req updateUserGroupRequest) (map[string]any, error) { result, err := s.db.Exec(ctx, ` UPDATE user_groups SET name = COALESCE($2, name), description = COALESCE($3, description), updated_at = NOW() WHERE id = $1 `, groupID, trimStringPtr(req.Name), trimStringPtr(req.Description)) if err != nil { if isUniqueViolation(err) { return nil, ErrGroupNameTaken } return nil, err } if result.RowsAffected() == 0 { return nil, ErrNotFound } s.logAudit(ctx, actorSub, "update_user_group", map[string]any{"group_id": groupID}) return s.GetUserGroup(ctx, groupID) } func (s *Service) DeleteUserGroup(ctx context.Context, actorSub, groupID string) error { result, err := s.db.Exec(ctx, `DELETE FROM user_groups WHERE id = $1`, groupID) if err != nil { return err } if result.RowsAffected() == 0 { return ErrNotFound } s.logAudit(ctx, actorSub, "delete_user_group", map[string]any{"group_id": groupID}) return nil } type setGroupMembersRequest struct { UserIDs []string `json:"user_ids"` } func (s *Service) SetGroupMembers(ctx context.Context, actorSub, groupID string, req setGroupMembersRequest) (map[string]any, error) { exists, err := s.groupExists(ctx, groupID) if err != nil { return nil, err } if !exists { return nil, ErrNotFound } tx, err := s.db.Begin(ctx) if err != nil { return nil, err } defer tx.Rollback(ctx) if _, err := tx.Exec(ctx, `DELETE FROM user_group_members WHERE group_id = $1`, groupID); err != nil { return nil, err } for _, userID := range req.UserIDs { userID = strings.TrimSpace(userID) if userID == "" { continue } ok, err := s.userExistsTx(ctx, tx, userID) if err != nil { return nil, err } if !ok { continue } if _, err := tx.Exec(ctx, ` INSERT INTO user_group_members (group_id, user_id) VALUES ($1, $2) ON CONFLICT DO NOTHING `, groupID, userID); err != nil { return nil, err } } if err := tx.Commit(ctx); err != nil { return nil, err } s.logAudit(ctx, actorSub, "set_user_group_members", map[string]any{ "group_id": groupID, "member_count": len(req.UserIDs), }) return s.GetUserGroup(ctx, groupID) } func (s *Service) AddUsersToGroup(ctx context.Context, actorSub, groupID string, userIDs []string) error { exists, err := s.groupExists(ctx, groupID) if err != nil { return err } if !exists { return ErrNotFound } for _, userID := range userIDs { userID = strings.TrimSpace(userID) if userID == "" { continue } ok, err := s.userExists(ctx, userID) if err != nil { return err } if !ok { continue } if _, err := s.db.Exec(ctx, ` INSERT INTO user_group_members (group_id, user_id) VALUES ($1, $2) ON CONFLICT DO NOTHING `, groupID, userID); err != nil { return err } } s.logAudit(ctx, actorSub, "add_users_to_group", map[string]any{ "group_id": groupID, "count": len(userIDs), }) return nil } func (s *Service) RemoveUsersFromGroup(ctx context.Context, actorSub, groupID string, userIDs []string) error { exists, err := s.groupExists(ctx, groupID) if err != nil { return err } if !exists { return ErrNotFound } for _, userID := range userIDs { userID = strings.TrimSpace(userID) if userID == "" { continue } if _, err := s.db.Exec(ctx, ` DELETE FROM user_group_members WHERE group_id = $1 AND user_id = $2 `, groupID, userID); err != nil { return err } } s.logAudit(ctx, actorSub, "remove_users_from_group", map[string]any{ "group_id": groupID, "count": len(userIDs), }) return nil } func (s *Service) groupExists(ctx context.Context, groupID string) (bool, error) { var exists bool if err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM user_groups WHERE id = $1)`, groupID).Scan(&exists); err != nil { return false, err } return exists, nil } func (s *Service) userExistsTx(ctx context.Context, tx pgx.Tx, userID string) (bool, error) { var exists bool if err := tx.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)`, userID).Scan(&exists); err != nil { return false, err } return exists, nil } type userGroupRowScanner interface { Scan(dest ...any) error } func scanUserGroupRow(row userGroupRowScanner) (map[string]any, error) { var id, name, description string var memberCount int64 var createdAt, updatedAt any if err := row.Scan(&id, &name, &description, &createdAt, &updatedAt, &memberCount); err != nil { return nil, err } return map[string]any{ "id": id, "name": name, "description": description, "member_count": memberCount, "created_at": createdAt, "updated_at": updatedAt, }, nil } func (s *Service) attachUsersGroups(ctx context.Context, users []map[string]any) error { if len(users) == 0 { return nil } ids := make([]string, 0, len(users)) byID := make(map[string]map[string]any, len(users)) for _, user := range users { id, _ := user["id"].(string) if id == "" { continue } ids = append(ids, id) byID[id] = user user["groups"] = []map[string]any{} } if len(ids) == 0 { return nil } rows, err := s.db.Query(ctx, ` SELECT ugm.user_id::text, g.id, g.name FROM user_group_members ugm JOIN user_groups g ON g.id = ugm.group_id WHERE ugm.user_id = ANY($1::uuid[]) ORDER BY LOWER(g.name) `, ids) if err != nil { return err } defer rows.Close() for rows.Next() { var userID, groupID, name string if err := rows.Scan(&userID, &groupID, &name); err != nil { return err } user, ok := byID[userID] if !ok { continue } groups, _ := user["groups"].([]map[string]any) user["groups"] = append(groups, map[string]any{ "id": groupID, "name": name, }) } return rows.Err() } func isUniqueViolation(err error) bool { var pgErr *pgconn.PgError return errors.As(err, &pgErr) && pgErr.Code == "23505" }