feat(deploy): enhance Nginx configuration and API integration for UltiAI
- Updated .env.example to include new configuration options for the UltiAI branding and API endpoints. - Enhanced Nginx configuration to support new API routes for the MCP and WebSocket connections. - Introduced sub-filters for branding adjustments in Nginx responses. - Added new JavaScript patch for API endpoint adjustments. - Implemented tests for new API functionalities and improved error handling in the AI gateway.
This commit is contained in:
parent
1e4e373f93
commit
621b0099d6
@ -181,6 +181,12 @@ BYPASS_MODEL_ACCESS_CONTROL=true
|
||||
WEBUI_NAME=UltiAI
|
||||
AI_ASSISTANT_PUBLIC_PATH=/ai
|
||||
ULTIMAIL_MCP_URL=http://ultimail-mcp:3100
|
||||
# Public MCP endpoint (nginx → ultid → ultimail-mcp): /api/v1/ai/mcp
|
||||
# OpenWebUI uses AI_GATEWAY_API_KEY + forwarded X-OpenWebUI-User-* headers for per-user tokens.
|
||||
# Ultimail MCP auto-enabled for all users (see deploy/openwebui/docker-compose.openwebui.yml):
|
||||
# DEFAULT_MODEL_PARAMS={"function_calling":"native"}
|
||||
# DEFAULT_MODEL_METADATA={"toolIds":["server:mcp:ultimail"],...}
|
||||
# TOOL_SERVER_CONNECTIONS with config.access_grants public read
|
||||
# OpenWebUI utilise POSTGRES_USER/POSTGRES_PASSWORD (base openwebui créée dans init-db.sh)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@ -6,6 +6,8 @@ services:
|
||||
- "80:80"
|
||||
volumes:
|
||||
- ./nginx/default.conf.template:/etc/nginx/templates/default.conf.template:ro
|
||||
- ./nginx/openwebui-subfilters.conf:/etc/nginx/openwebui-subfilters.conf:ro
|
||||
- ./nginx/patches/QGuclOcQ.js:/etc/nginx/patches/QGuclOcQ.js:ro
|
||||
environment:
|
||||
DOMAIN: ${DOMAIN:-localhost}
|
||||
MAIL_FRONTEND_UPSTREAM: ${MAIL_FRONTEND_UPSTREAM:-host.docker.internal:3004}
|
||||
|
||||
@ -180,6 +180,40 @@ server {
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
}
|
||||
location ^~ /api/v1/ai/mcp {
|
||||
resolver 127.0.0.11 valid=10s ipv6=off;
|
||||
set $ultid_upstream ultid:8080;
|
||||
proxy_hide_header Access-Control-Allow-Origin;
|
||||
proxy_hide_header Access-Control-Allow-Methods;
|
||||
proxy_hide_header Access-Control-Allow-Headers;
|
||||
proxy_hide_header Access-Control-Expose-Headers;
|
||||
proxy_hide_header Access-Control-Max-Age;
|
||||
proxy_hide_header Access-Control-Allow-Credentials;
|
||||
proxy_hide_header Vary;
|
||||
add_header Access-Control-Allow-Origin $cors_allow_origin always;
|
||||
add_header Access-Control-Allow-Methods "GET, HEAD, POST, PUT, PATCH, DELETE, OPTIONS" always;
|
||||
add_header Access-Control-Allow-Headers "Accept, Authorization, Content-Type, Idempotency-Key, Mcp-Session-Id, Origin, X-Requested-With, X-Trace-Id, X-Ulti-Token, X-OpenWebUI-User-Email, X-OpenWebUI-User-Id, X-OpenWebUI-User-Name, X-OpenWebUI-User-Role" always;
|
||||
add_header Access-Control-Expose-Headers "Mcp-Session-Id, X-Trace-Id" always;
|
||||
add_header Access-Control-Max-Age 300 always;
|
||||
add_header Vary Origin always;
|
||||
proxy_pass http://$ultid_upstream;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header Cookie $http_cookie;
|
||||
proxy_set_header Authorization $http_authorization;
|
||||
proxy_set_header X-Ulti-Token $http_x_ulti_token;
|
||||
proxy_set_header X-OpenWebUI-User-Email $http_x_openwebui_user_email;
|
||||
proxy_set_header X-OpenWebUI-User-Id $http_x_openwebui_user_id;
|
||||
proxy_set_header X-OpenWebUI-User-Name $http_x_openwebui_user_name;
|
||||
proxy_set_header X-OpenWebUI-User-Role $http_x_openwebui_user_role;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
proxy_read_timeout 86400s;
|
||||
proxy_send_timeout 86400s;
|
||||
}
|
||||
location ^~ /api/v1/ai/ {
|
||||
resolver 127.0.0.11 valid=10s ipv6=off;
|
||||
set $ultid_upstream ultid:8080;
|
||||
@ -192,16 +226,27 @@ server {
|
||||
proxy_hide_header Vary;
|
||||
add_header Access-Control-Allow-Origin $cors_allow_origin always;
|
||||
add_header Access-Control-Allow-Methods "GET, HEAD, POST, PUT, PATCH, DELETE, OPTIONS" always;
|
||||
add_header Access-Control-Allow-Headers "Accept, Authorization, Content-Type, Idempotency-Key, Origin, X-Requested-With, X-Trace-Id" always;
|
||||
add_header Access-Control-Expose-Headers "X-Trace-Id" always;
|
||||
add_header Access-Control-Allow-Headers "Accept, Authorization, Content-Type, Idempotency-Key, Mcp-Session-Id, Origin, X-Requested-With, X-Trace-Id, X-Ulti-Token, X-OpenWebUI-User-Email, X-OpenWebUI-User-Id, X-OpenWebUI-User-Name, X-OpenWebUI-User-Role" always;
|
||||
add_header Access-Control-Expose-Headers "Mcp-Session-Id, X-Trace-Id" always;
|
||||
add_header Access-Control-Max-Age 300 always;
|
||||
add_header Vary Origin always;
|
||||
proxy_pass http://$ultid_upstream;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header Cookie $http_cookie;
|
||||
proxy_set_header Authorization $http_authorization;
|
||||
proxy_set_header X-Ulti-Token $http_x_ulti_token;
|
||||
proxy_set_header X-OpenWebUI-User-Email $http_x_openwebui_user_email;
|
||||
proxy_set_header X-OpenWebUI-User-Id $http_x_openwebui_user_id;
|
||||
proxy_set_header X-OpenWebUI-User-Name $http_x_openwebui_user_name;
|
||||
proxy_set_header X-OpenWebUI-User-Role $http_x_openwebui_user_role;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
proxy_read_timeout 86400s;
|
||||
proxy_send_timeout 86400s;
|
||||
}
|
||||
location ^~ /api/v1/users/me {
|
||||
resolver 127.0.0.11 valid=10s ipv6=off;
|
||||
@ -446,6 +491,27 @@ server {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
proxy_read_timeout 86400s;
|
||||
proxy_send_timeout 86400s;
|
||||
}
|
||||
|
||||
# OpenWebUI socket.io under /ai prefix (SvelteKit base /ai — avoid embed-auth on /ai/)
|
||||
location ^~ /ai/ws/socket.io {
|
||||
resolver 127.0.0.11 valid=10s ipv6=off;
|
||||
set $openwebui_upstream openwebui:8080;
|
||||
rewrite ^/ai/ws/socket.io(.*)$ /ws/socket.io$1 break;
|
||||
proxy_pass http://$openwebui_upstream;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
proxy_read_timeout 86400s;
|
||||
proxy_send_timeout 86400s;
|
||||
}
|
||||
@ -648,6 +714,13 @@ server {
|
||||
# SvelteKit base path — without this, /ai/ routes 404 (base "" expects site root)
|
||||
sub_filter 'base: ""' 'base: "/ai"';
|
||||
sub_filter "base: ''" 'base: "/ai"';
|
||||
# Keep t="" so socket.io uses default namespace (/), not /ai — see openwebui-subfilters.conf
|
||||
include /etc/nginx/openwebui-subfilters.conf;
|
||||
sub_filter 'location.href="/auth"' 'location.href="/ai/auth"';
|
||||
sub_filter "location.href='/auth'" "location.href='/ai/auth'";
|
||||
sub_filter '"/auth?"' '"/ai/auth?"';
|
||||
sub_filter '"/auth"' '"/ai/auth"';
|
||||
sub_filter '<script src="/static/loader.js"' '<script src="/ai/static/ulti-session.js"></script><script src="/ai/static/loader.js"';
|
||||
# In-app links that escape to site root (e.g. "Nouvelle conversation" → /)
|
||||
sub_filter 'href="/"' 'href="/ai/"';
|
||||
sub_filter "href='/'" "href='/ai/'";
|
||||
@ -668,6 +741,43 @@ server {
|
||||
return 301 /ai/;
|
||||
}
|
||||
|
||||
# OpenWebUI trusted-header signin under /ai prefix (iframe loads /ai/api/v1/auths/signin).
|
||||
location = /ai/api/v1/auths/signin {
|
||||
resolver 127.0.0.11 valid=10s ipv6=off;
|
||||
set $ultid_upstream ultid:8080;
|
||||
|
||||
auth_request /api/v1/ai/embed-auth;
|
||||
auth_request_set $ulti_user_email $upstream_http_x_ulti_user_email;
|
||||
auth_request_set $ulti_user_name $upstream_http_x_ulti_user_name;
|
||||
auth_request_set $ulti_user_role $upstream_http_x_ulti_user_role;
|
||||
|
||||
proxy_hide_header Access-Control-Allow-Origin;
|
||||
proxy_hide_header Access-Control-Allow-Methods;
|
||||
proxy_hide_header Access-Control-Allow-Headers;
|
||||
proxy_hide_header Access-Control-Expose-Headers;
|
||||
proxy_hide_header Access-Control-Max-Age;
|
||||
proxy_hide_header Access-Control-Allow-Credentials;
|
||||
proxy_hide_header Vary;
|
||||
add_header Access-Control-Allow-Origin $cors_allow_origin always;
|
||||
add_header Access-Control-Allow-Methods "GET, HEAD, POST, PUT, PATCH, DELETE, OPTIONS" always;
|
||||
add_header Access-Control-Allow-Headers "Accept, Authorization, Content-Type, Idempotency-Key, Origin, X-Requested-With, X-Trace-Id" always;
|
||||
add_header Access-Control-Expose-Headers "X-Trace-Id" always;
|
||||
add_header Access-Control-Max-Age 300 always;
|
||||
add_header Vary Origin always;
|
||||
|
||||
proxy_pass http://$ultid_upstream/api/v1/ai/embed-signin;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header Cookie $http_cookie;
|
||||
proxy_set_header Authorization $http_authorization;
|
||||
proxy_set_header X-Ulti-User-Email $ulti_user_email;
|
||||
proxy_set_header X-Ulti-User-Name $ulti_user_name;
|
||||
proxy_set_header X-Ulti-User-Role $ulti_user_role;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
}
|
||||
|
||||
# OpenWebUI API (prefix /ai/api — évite collision avec ultid /api/v1/)
|
||||
location /ai/api/ {
|
||||
resolver 127.0.0.11 valid=10s ipv6=off;
|
||||
@ -682,10 +792,24 @@ server {
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
proxy_read_timeout 86400s;
|
||||
proxy_send_timeout 86400s;
|
||||
}
|
||||
|
||||
# OpenWebUI constants chunk — t="" keeps socket.io on default namespace; /ai API paths baked in
|
||||
location = /ai/_app/immutable/chunks/QGuclOcQ.js {
|
||||
auth_request /api/v1/ai/embed-auth;
|
||||
default_type application/javascript;
|
||||
alias /etc/nginx/patches/QGuclOcQ.js;
|
||||
}
|
||||
location = /_app/immutable/chunks/QGuclOcQ.js {
|
||||
auth_request /api/v1/ai/embed-auth;
|
||||
default_type application/javascript;
|
||||
alias /etc/nginx/patches/QGuclOcQ.js;
|
||||
}
|
||||
|
||||
# OpenWebUI assets (SPA uses root-relative /static, /_app — not /ai/…)
|
||||
location ^~ /static/ {
|
||||
resolver 127.0.0.11 valid=10s ipv6=off;
|
||||
@ -711,6 +835,10 @@ server {
|
||||
sub_filter_types application/javascript text/javascript;
|
||||
sub_filter '"/api/' '"/ai/api/';
|
||||
sub_filter "'/api/" "'/ai/api/";
|
||||
include /etc/nginx/openwebui-subfilters.conf;
|
||||
sub_filter 'location.href="/auth"' 'location.href="/ai/auth"';
|
||||
sub_filter '"/auth?"' '"/ai/auth?"';
|
||||
sub_filter '"/auth"' '"/ai/auth"';
|
||||
sub_filter 'href="/"' 'href="/ai/"';
|
||||
sub_filter "href='/'" "href='/ai/'";
|
||||
sub_filter 'Open WebUI' 'UltiAI';
|
||||
@ -741,6 +869,10 @@ server {
|
||||
sub_filter_types application/javascript text/javascript application/json;
|
||||
sub_filter '"/api/' '"/ai/api/';
|
||||
sub_filter "'/api/" "'/ai/api/";
|
||||
include /etc/nginx/openwebui-subfilters.conf;
|
||||
sub_filter 'location.href="/auth"' 'location.href="/ai/auth"';
|
||||
sub_filter '"/auth?"' '"/ai/auth?"';
|
||||
sub_filter '"/auth"' '"/ai/auth"';
|
||||
sub_filter 'href="/"' 'href="/ai/"';
|
||||
sub_filter "href='/'" "href='/ai/'";
|
||||
sub_filter 'Open WebUI' 'UltiAI';
|
||||
|
||||
4
deploy/nginx/openwebui-subfilters.conf
Normal file
4
deploy/nginx/openwebui-subfilters.conf
Normal file
@ -0,0 +1,4 @@
|
||||
# Branding only — API base paths served via patches/QGuclOcQ.js (sub_filter cannot match ${t}).
|
||||
sub_filter 'const e="Open WebUI",t=""' 'const e="UltiAI",t=""';
|
||||
sub_filter '<title>Open WebUI</title>' '<title>UltiAI</title>';
|
||||
sub_filter '<title>UltiAI (Open WebUI)</title>' '<title>UltiAI</title>';
|
||||
1
deploy/nginx/patches/QGuclOcQ.js
Normal file
1
deploy/nginx/patches/QGuclOcQ.js
Normal file
@ -0,0 +1 @@
|
||||
const e="UltiAI",t="",_=`/ai/api/v1`,a=`/ai/ollama`,s=`/ai/openai`,A=`/ai/api/v1/audio`,o=`/ai/api/v1/images`,I=`/ai/api/v1/retrieval`,n="0.9.6",E="02dc3e689ceac915a870b373318b99c029ddf603",c={file_context:!0,vision:!0,file_upload:!0,web_search:!0,image_generation:!0,code_interpreter:!0,terminal:!0,citations:!0,status_updates:!0,usage:void 0,builtin_tools:!0},i=1e3;export{A,c as D,o as I,s as O,i as P,I as R,n as W,t as a,_ as b,a as c,e as d,E as e};
|
||||
@ -10,6 +10,9 @@ services:
|
||||
WEBUI_AUTH_TRUSTED_ROLE_HEADER: X-Ulti-User-Role
|
||||
ENABLE_PERSISTENT_CONFIG: "false"
|
||||
ENABLE_DIRECT_CONNECTIONS: "false"
|
||||
ENABLE_FORWARD_USER_INFO_HEADERS: "true"
|
||||
# Polling fallback — more reliable than websocket-only through nginx + iframe embed
|
||||
ENABLE_WEBSOCKET_SUPPORT: "false"
|
||||
BYPASS_MODEL_ACCESS_CONTROL: "true"
|
||||
WEBUI_NAME: UltiAI
|
||||
OPENAI_API_BASE_URL: http://ultid:8080/api/v1/ai
|
||||
@ -18,12 +21,25 @@ services:
|
||||
WEBUI_SECRET_KEY: ${WEBUI_SECRET_KEY:-changeme-openwebui-dev-secret}
|
||||
DATABASE_URL: postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/openwebui
|
||||
USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED: "false"
|
||||
ENABLE_MCP: "true"
|
||||
MCP_INITIALIZE_TIMEOUT: "30"
|
||||
USER_PERMISSIONS_FEATURES_DIRECT_TOOL_SERVERS: "true"
|
||||
DEFAULT_MODEL_PARAMS: '{"function_calling":"native"}'
|
||||
DEFAULT_MODEL_METADATA: '{"toolIds":["server:mcp:ultimail"],"capabilities":{"builtin_tools":true}}'
|
||||
TOOL_SERVER_CONNECTIONS: >-
|
||||
[{"type":"mcp","url":"http://nginx/api/v1/ai/mcp","auth_type":"bearer","key":"${AI_GATEWAY_API_KEY:-ulti-gateway}","config":{"enable":true,"access_grants":[{"principal_type":"user","principal_id":"*","permission":"read"}]},"info":{"id":"ultimail","name":"Ultimail","description":"Mail, drive, contacts, agenda, recherche suite et recherche web"}}]
|
||||
volumes:
|
||||
- openwebui_data:/app/backend/data
|
||||
- ../services/openwebui/pipelines:/app/pipelines/custom:ro
|
||||
- ../services/openwebui/static/logo.png:/app/backend/open_webui/static/logo.png:ro
|
||||
- ../services/openwebui/static/favicon.png:/app/backend/open_webui/static/favicon.png:ro
|
||||
- ../services/openwebui/static/favicon.png:/app/backend/open_webui/static/favicon-96x96.png:ro
|
||||
- ../services/openwebui/static/favicon.png:/app/backend/open_webui/static/apple-touch-icon.png:ro
|
||||
- ../services/openwebui/static/favicon.png:/app/backend/open_webui/static/favicon-dark.png:ro
|
||||
- ../services/openwebui/static/ultiai-mark.svg:/app/backend/open_webui/static/favicon.svg:ro
|
||||
- ../services/openwebui/static/favicon.ico:/app/backend/open_webui/static/favicon.ico:ro
|
||||
- ../services/openwebui/static/custom.css:/app/backend/open_webui/static/custom.css:ro
|
||||
- ../services/openwebui/static/ulti-session.js:/app/backend/open_webui/static/ulti-session.js:ro
|
||||
networks:
|
||||
- ulti-net
|
||||
depends_on:
|
||||
|
||||
@ -115,37 +115,56 @@ func (g *Gateway) listModelsFromSettings(ctx context.Context, settings llm.Setti
|
||||
return ApplyModelCatalog(out, policy.Models), nil
|
||||
}
|
||||
|
||||
func (g *Gateway) ProxyChatCompletions(ctx context.Context, externalUserID string, body []byte, w http.ResponseWriter) error {
|
||||
if err := g.quota.AssertAvailable(ctx, externalUserID); err != nil {
|
||||
return err
|
||||
func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID string, useOrgSettings bool, body []byte, w http.ResponseWriter) error {
|
||||
if strings.TrimSpace(quotaExternalUserID) != "" {
|
||||
if err := g.quota.AssertAvailable(ctx, quotaExternalUserID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var req chatCompletionRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
var modelProbe struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &modelProbe); err != nil {
|
||||
return fmt.Errorf("invalid chat completion request: %w", err)
|
||||
}
|
||||
|
||||
policy, _ := LoadAssistantPolicy(ctx, g.db)
|
||||
if !IsModelAllowed(req.Model, policy.Models) {
|
||||
return fmt.Errorf("model %q is not allowed", strings.TrimSpace(req.Model))
|
||||
if !IsModelAllowed(modelProbe.Model, policy.Models) {
|
||||
return fmt.Errorf("model %q is not allowed", strings.TrimSpace(modelProbe.Model))
|
||||
}
|
||||
|
||||
settings, err := LoadEffectiveLLMSettings(ctx, g.db, externalUserID)
|
||||
var (
|
||||
settings llm.Settings
|
||||
err error
|
||||
)
|
||||
if useOrgSettings {
|
||||
settings, err = LoadOrgLLMSettings(ctx, g.db)
|
||||
} else {
|
||||
settings, err = LoadEffectiveLLMSettings(ctx, g.db, quotaExternalUserID)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
provider, model, err := resolveProviderForModel(settings, req.Model)
|
||||
provider, model, err := resolveProviderForModel(settings, modelProbe.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.TrimSpace(req.Model) == "" {
|
||||
req.Model = model
|
||||
}
|
||||
|
||||
upstreamBody, err := json.Marshal(req)
|
||||
upstreamBody, err := repairChatCompletionBody(body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stream := chatCompletionStreamRequested(upstreamBody)
|
||||
if strings.TrimSpace(modelProbe.Model) == "" {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(upstreamBody, &req); err == nil {
|
||||
req["model"] = model
|
||||
if patched, err := json.Marshal(req); err == nil {
|
||||
upstreamBody = patched
|
||||
}
|
||||
}
|
||||
}
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
|
||||
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(upstreamBody))
|
||||
if err != nil {
|
||||
@ -162,8 +181,8 @@ func (g *Gateway) ProxyChatCompletions(ctx context.Context, externalUserID strin
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if req.Stream {
|
||||
return g.proxyStream(ctx, externalUserID, w, resp)
|
||||
if stream {
|
||||
return g.proxyStream(ctx, quotaExternalUserID, w, resp)
|
||||
}
|
||||
payload, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||||
if err != nil {
|
||||
@ -175,19 +194,19 @@ func (g *Gateway) ProxyChatCompletions(ctx context.Context, externalUserID strin
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil
|
||||
}
|
||||
tokens := extractUsageTokens(payload)
|
||||
_ = g.quota.Record(ctx, externalUserID, tokens)
|
||||
if strings.TrimSpace(quotaExternalUserID) != "" {
|
||||
tokens := extractUsageTokens(payload)
|
||||
_ = g.quota.Record(ctx, quotaExternalUserID, tokens)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) proxyStream(ctx context.Context, externalUserID string, w http.ResponseWriter, resp *http.Response) error {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return fmt.Errorf("streaming not supported")
|
||||
}
|
||||
func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w http.ResponseWriter, resp *http.Response) error {
|
||||
rc := http.NewResponseController(w)
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
@ -196,7 +215,9 @@ func (g *Gateway) proxyStream(ctx context.Context, externalUserID string, w http
|
||||
line, err := reader.ReadString('\n')
|
||||
if len(line) > 0 {
|
||||
_, _ = w.Write([]byte(line))
|
||||
flusher.Flush()
|
||||
if err := rc.Flush(); err != nil {
|
||||
return fmt.Errorf("streaming not supported: %w", err)
|
||||
}
|
||||
if strings.HasPrefix(line, "data: ") && !strings.Contains(line, "[DONE]") {
|
||||
totalTokens += extractStreamUsageTokens([]byte(strings.TrimPrefix(strings.TrimSpace(line), "data: ")))
|
||||
}
|
||||
@ -213,11 +234,11 @@ func (g *Gateway) proxyStream(ctx context.Context, externalUserID string, w http
|
||||
default:
|
||||
}
|
||||
}
|
||||
if resp.StatusCode < 400 {
|
||||
if resp.StatusCode < 400 && strings.TrimSpace(quotaExternalUserID) != "" {
|
||||
if totalTokens == 0 {
|
||||
totalTokens = 1
|
||||
}
|
||||
_ = g.quota.Record(ctx, externalUserID, totalTokens)
|
||||
_ = g.quota.Record(ctx, quotaExternalUserID, totalTokens)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
111
internal/ai/gateway_messages.go
Normal file
111
internal/ai/gateway_messages.go
Normal file
@ -0,0 +1,111 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func repairChatCompletionBody(body []byte) ([]byte, error) {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body, fmt.Errorf("repair chat completion body: %w", err)
|
||||
}
|
||||
|
||||
rawMessages, ok := req["messages"].([]any)
|
||||
if !ok || len(rawMessages) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
repaired := repairToolMessages(rawMessages)
|
||||
if sameMessages(rawMessages, repaired) {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
req["messages"] = repaired
|
||||
return json.Marshal(req)
|
||||
}
|
||||
|
||||
func repairToolMessages(messages []any) []any {
|
||||
out := make([]any, 0, len(messages)+1)
|
||||
for _, raw := range messages {
|
||||
msg, ok := raw.(map[string]any)
|
||||
if !ok {
|
||||
out = append(out, raw)
|
||||
continue
|
||||
}
|
||||
|
||||
if messageRole(msg) == "tool" && !previousMessageHasToolCalls(out) {
|
||||
toolCallID := messageToolCallID(msg)
|
||||
if toolCallID == "" {
|
||||
toolCallID = "call_repaired"
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": []any{
|
||||
map[string]any{
|
||||
"id": toolCallID,
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "tool",
|
||||
"arguments": "{}",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func previousMessageHasToolCalls(messages []any) bool {
|
||||
if len(messages) == 0 {
|
||||
return false
|
||||
}
|
||||
msg, ok := messages[len(messages)-1].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if messageRole(msg) != "assistant" {
|
||||
return false
|
||||
}
|
||||
toolCalls, ok := msg["tool_calls"].([]any)
|
||||
return ok && len(toolCalls) > 0
|
||||
}
|
||||
|
||||
func messageRole(msg map[string]any) string {
|
||||
role, _ := msg["role"].(string)
|
||||
return role
|
||||
}
|
||||
|
||||
func messageToolCallID(msg map[string]any) string {
|
||||
id, _ := msg["tool_call_id"].(string)
|
||||
return id
|
||||
}
|
||||
|
||||
func sameMessages(before, after []any) bool {
|
||||
if len(before) != len(after) {
|
||||
return false
|
||||
}
|
||||
b, err := json.Marshal(before)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
a, err := json.Marshal(after)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return string(b) == string(a)
|
||||
}
|
||||
|
||||
func chatCompletionStreamRequested(body []byte) bool {
|
||||
var probe struct {
|
||||
Stream *bool `json:"stream"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &probe); err != nil {
|
||||
return false
|
||||
}
|
||||
return probe.Stream != nil && *probe.Stream
|
||||
}
|
||||
112
internal/ai/gateway_messages_test.go
Normal file
112
internal/ai/gateway_messages_test.go
Normal file
@ -0,0 +1,112 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRepairChatCompletionBody_preservesToolCalls(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model":"gpt-4o",
|
||||
"messages":[
|
||||
{"role":"user","content":"list mail"},
|
||||
{"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"mail_list","arguments":"{\"limit\":3}"}}]},
|
||||
{"role":"tool","tool_call_id":"call_1","content":"{\"items\":[]}"}
|
||||
]
|
||||
}`)
|
||||
|
||||
repaired, err := repairChatCompletionBody(body)
|
||||
if err != nil {
|
||||
t.Fatalf("repairChatCompletionBody() error = %v", err)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(repaired, &parsed); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
messages := parsed["messages"].([]any)
|
||||
assistant := messages[1].(map[string]any)
|
||||
toolCalls, ok := assistant["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) != 1 {
|
||||
t.Fatalf("expected assistant tool_calls preserved, got %#v", assistant)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairChatCompletionBody_insertsMissingAssistantToolCalls(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"model":"gpt-4o",
|
||||
"messages":[
|
||||
{"role":"system","content":"prompt"},
|
||||
{"role":"user","content":"list mail"},
|
||||
{"role":"tool","tool_call_id":"call_abc","content":"{\"items\":[]}"}
|
||||
]
|
||||
}`)
|
||||
|
||||
repaired, err := repairChatCompletionBody(body)
|
||||
if err != nil {
|
||||
t.Fatalf("repairChatCompletionBody() error = %v", err)
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(repaired, &parsed); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
messages := parsed["messages"].([]any)
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("messages len = %d, want 4", len(messages))
|
||||
}
|
||||
assistant := messages[2].(map[string]any)
|
||||
if messageRole(assistant) != "assistant" {
|
||||
t.Fatalf("messages[2].role = %q, want assistant", messageRole(assistant))
|
||||
}
|
||||
toolCalls, ok := assistant["tool_calls"].([]any)
|
||||
if !ok || len(toolCalls) != 1 {
|
||||
t.Fatalf("expected inserted tool_calls, got %#v", assistant)
|
||||
}
|
||||
call := toolCalls[0].(map[string]any)
|
||||
if call["id"] != "call_abc" {
|
||||
t.Fatalf("tool_call id = %v, want call_abc", call["id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairChatCompletionBody_stripsNothingWhenAlreadyValid(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`)
|
||||
repaired, err := repairChatCompletionBody(body)
|
||||
if err != nil {
|
||||
t.Fatalf("repairChatCompletionBody() error = %v", err)
|
||||
}
|
||||
if string(repaired) != string(body) {
|
||||
t.Fatalf("expected unchanged body, got %s", repaired)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairToolMessages_doesNotDuplicateAssistant(t *testing.T) {
|
||||
messages := []any{
|
||||
map[string]any{"role": "user", "content": "list"},
|
||||
map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": []any{
|
||||
map[string]any{"id": "call_1", "type": "function"},
|
||||
},
|
||||
},
|
||||
map[string]any{"role": "tool", "tool_call_id": "call_1", "content": "{}"},
|
||||
}
|
||||
|
||||
repaired := repairToolMessages(messages)
|
||||
if len(repaired) != 3 {
|
||||
t.Fatalf("len = %d, want 3", len(repaired))
|
||||
}
|
||||
if !strings.Contains(string(mustJSON(repaired)), "call_1") {
|
||||
t.Fatal("expected original tool_call_id preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func mustJSON(v any) []byte {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
78
internal/ai/openwebui_profile.go
Normal file
78
internal/ai/openwebui_profile.go
Normal file
@ -0,0 +1,78 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
"github.com/ultisuite/ulti-backend/internal/permission"
|
||||
)
|
||||
|
||||
const defaultOpenWebUIProfileImage = "/user.png"
|
||||
|
||||
// SyncOpenWebUIProfile pushes profile_image_url to OpenWebUI via trusted-header auth.
|
||||
func SyncOpenWebUIProfile(ctx context.Context, cfg *config.Config, claims *auth.Claims, profileImageURL string) error {
|
||||
if cfg == nil || !cfg.AIAssistantEnabled || claims == nil {
|
||||
return nil
|
||||
}
|
||||
email := strings.TrimSpace(claims.Email)
|
||||
if email == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(cfg.OpenWebUIInternalURL), "/")
|
||||
if baseURL == "" {
|
||||
baseURL = "http://openwebui:8080"
|
||||
}
|
||||
|
||||
profileImageURL = strings.TrimSpace(profileImageURL)
|
||||
if profileImageURL == "" {
|
||||
profileImageURL = defaultOpenWebUIProfileImage
|
||||
}
|
||||
|
||||
body, err := json.Marshal(map[string]string{
|
||||
"profile_image_url": profileImageURL,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/api/v1/auths/update/profile", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Ulti-User-Email", email)
|
||||
name := strings.TrimSpace(claims.Name)
|
||||
if name == "" {
|
||||
name = email
|
||||
}
|
||||
req.Header.Set("X-Ulti-User-Name", name)
|
||||
req.Header.Set("X-Ulti-User-Role", openWebUIRoleFromClaims(claims))
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
return nil
|
||||
}
|
||||
raw, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return fmt.Errorf("openwebui profile sync: %d %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
||||
}
|
||||
|
||||
func openWebUIRoleFromClaims(claims *auth.Claims) string {
|
||||
if claims != nil && permission.HasRole(claims.Groups, permission.RoleAdmin) {
|
||||
return "admin"
|
||||
}
|
||||
return "user"
|
||||
}
|
||||
@ -165,8 +165,8 @@ func LoadAssistantPolicy(ctx context.Context, db *pgxpool.Pool) (AssistantPolicy
|
||||
defaults := AssistantPolicy{
|
||||
Enabled: false,
|
||||
PublicPath: "/ai",
|
||||
EmbedDefaultTemporary: true,
|
||||
EnabledTools: []string{"mail", "drive", "contacts", "search"},
|
||||
EmbedDefaultTemporary: false,
|
||||
EnabledTools: []string{"mail", "drive", "contacts", "agenda", "search", "web_search"},
|
||||
ChatSyncEnabled: true,
|
||||
ChatNCPath: "/.ultimail/ai/chats",
|
||||
}
|
||||
@ -202,6 +202,32 @@ func LoadAssistantPolicy(ctx context.Context, db *pgxpool.Pool) (AssistantPolicy
|
||||
return stored, nil
|
||||
}
|
||||
|
||||
func ResolveDefaultModel(ctx context.Context, db *pgxpool.Pool, policy AssistantPolicy) string {
|
||||
if model := strings.TrimSpace(policy.DefaultModel); model != "" {
|
||||
return model
|
||||
}
|
||||
settings, err := LoadOrgLLMSettings(ctx, db)
|
||||
if err != nil || len(settings.Providers) == 0 {
|
||||
return ""
|
||||
}
|
||||
if defaultID := strings.TrimSpace(settings.DefaultProviderID); defaultID != "" {
|
||||
for _, provider := range settings.Providers {
|
||||
if provider.ID == defaultID {
|
||||
if model := strings.TrimSpace(provider.DefaultModel); model != "" {
|
||||
return model
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, provider := range settings.Providers {
|
||||
if model := strings.TrimSpace(provider.DefaultModel); model != "" {
|
||||
return model
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func LoadQuotaLimits(ctx context.Context, db *pgxpool.Pool) (QuotaLimits, error) {
|
||||
defaults := QuotaLimits{RequestsPerDay: 100, TokensPerMonth: 500_000}
|
||||
if db == nil {
|
||||
|
||||
@ -78,8 +78,10 @@ type SessionContext struct {
|
||||
}
|
||||
|
||||
type SessionResponse struct {
|
||||
SessionID string `json:"session_id"`
|
||||
EmbedURL string `json:"embed_url"`
|
||||
TokenSecret string `json:"token_secret,omitempty"`
|
||||
Temporary bool `json:"temporary"`
|
||||
SessionID string `json:"session_id"`
|
||||
EmbedURL string `json:"embed_url"`
|
||||
TokenSecret string `json:"token_secret,omitempty"`
|
||||
Temporary bool `json:"temporary"`
|
||||
MCPURL string `json:"mcp_url,omitempty"`
|
||||
EnabledTools []string `json:"enabled_tools,omitempty"`
|
||||
}
|
||||
|
||||
30
internal/ai/users.go
Normal file
30
internal/ai/users.go
Normal file
@ -0,0 +1,30 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// ResolveExternalIDByEmail maps an Ultimail account email to users.external_id (Authentik sub).
|
||||
func ResolveExternalIDByEmail(ctx context.Context, db *pgxpool.Pool, email string) (string, error) {
|
||||
email = strings.TrimSpace(email)
|
||||
if db == nil || email == "" {
|
||||
return "", nil
|
||||
}
|
||||
var externalID string
|
||||
err := db.QueryRow(ctx, `
|
||||
SELECT external_id FROM users
|
||||
WHERE lower(email) = lower($1) AND status != 'disabled'
|
||||
LIMIT 1
|
||||
`, email).Scan(&externalID)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(externalID), nil
|
||||
}
|
||||
@ -18,6 +18,7 @@ func (h *Handler) registerDriveAdminRoutes(r chi.Router, read, write func(http.H
|
||||
r.With(write).Put("/drive/org-folders/{folderID}", h.UpdateDriveOrgFolder)
|
||||
r.With(write).Delete("/drive/org-folders/{folderID}", h.DeleteDriveOrgFolder)
|
||||
r.With(write).Post("/drive/org-folders/sync", h.SyncDriveOrgFolders)
|
||||
h.registerDriveOrgMountRoutes(r, read, write)
|
||||
}
|
||||
|
||||
func (h *Handler) driveService() *drive.Service {
|
||||
|
||||
90
internal/api/admin/drive_org_mounts.go
Normal file
90
internal/api/admin/drive_org_mounts.go
Normal file
@ -0,0 +1,90 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
)
|
||||
|
||||
func (h *Handler) registerDriveOrgMountRoutes(r chi.Router, read, write func(http.Handler) http.Handler) {
|
||||
r.With(read).Get("/drive/org-mounts", h.ListDriveOrgMounts)
|
||||
r.With(write).Post("/drive/org-mounts", h.CreateDriveOrgMount)
|
||||
r.With(write).Delete("/drive/org-mounts/{mountID}", h.DeleteDriveOrgMount)
|
||||
}
|
||||
|
||||
func (h *Handler) ListDriveOrgMounts(w http.ResponseWriter, r *http.Request) {
|
||||
svc := h.driveService()
|
||||
mounts, err := svc.ListOrgMountsAdmin(r.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("list drive org mounts", "error", err)
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"mounts": mounts})
|
||||
}
|
||||
|
||||
func (h *Handler) CreateDriveOrgMount(w http.ResponseWriter, r *http.Request) {
|
||||
svc := h.driveService()
|
||||
var req struct {
|
||||
OrgSlug string `json:"org_slug"`
|
||||
DisplayName string `json:"display_name"`
|
||||
WebDAV nextcloud.WebDAVMountConfig `json:"webdav"`
|
||||
}
|
||||
if err := apivalidate.DecodeJSON(w, r, 32<<10, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if verr := validateOrgWebDAVMountRequest(req.OrgSlug, req.DisplayName, req.WebDAV); verr != nil {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
mount, err := svc.CreateOrgWebDAVMount(r.Context(), req.OrgSlug, req.DisplayName, req.WebDAV)
|
||||
if err != nil {
|
||||
writeDriveAdminError(w, r, err)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusCreated, mount)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteDriveOrgMount(w http.ResponseWriter, r *http.Request) {
|
||||
svc := h.driveService()
|
||||
mountID := chi.URLParam(r, "mountID")
|
||||
mount, err := svc.GetMountAdmin(r.Context(), mountID)
|
||||
if err != nil {
|
||||
writeDriveAdminError(w, r, err)
|
||||
return
|
||||
}
|
||||
if mount.Scope != "org" {
|
||||
apivalidate.WriteNotFound(w, r, "mount not found")
|
||||
return
|
||||
}
|
||||
if err := svc.DeleteMount(r.Context(), mountID); err != nil {
|
||||
writeDriveAdminError(w, r, err)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func validateOrgWebDAVMountRequest(orgSlug, displayName string, cfg nextcloud.WebDAVMountConfig) *apivalidate.ValidationError {
|
||||
var details []apivalidate.FieldDetail
|
||||
if strings.TrimSpace(orgSlug) == "" {
|
||||
details = append(details, apivalidate.FieldDetail{Field: "org_slug", Message: "required"})
|
||||
}
|
||||
if strings.TrimSpace(displayName) == "" {
|
||||
details = append(details, apivalidate.FieldDetail{Field: "display_name", Message: "required"})
|
||||
}
|
||||
if strings.TrimSpace(cfg.Host) == "" {
|
||||
details = append(details, apivalidate.FieldDetail{Field: "webdav.host", Message: "required"})
|
||||
}
|
||||
if strings.TrimSpace(cfg.User) == "" {
|
||||
details = append(details, apivalidate.FieldDetail{Field: "webdav.user", Message: "required"})
|
||||
}
|
||||
if len(details) > 0 {
|
||||
return &apivalidate.ValidationError{Details: details}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -59,6 +59,14 @@ func (h *Handler) Routes() chi.Router {
|
||||
r.With(write).Put("/users/{userID}/quota", h.SetQuota)
|
||||
r.With(write).Put("/users/{userID}/role", h.SetUserRole)
|
||||
r.With(write).Delete("/users/{userID}", h.DeleteUser)
|
||||
r.With(write).Post("/users/bulk", h.BulkUsersAction)
|
||||
|
||||
r.With(read).Get("/user-groups", h.ListUserGroups)
|
||||
r.With(read).Get("/user-groups/{groupID}", h.GetUserGroup)
|
||||
r.With(write).Post("/user-groups", h.CreateUserGroup)
|
||||
r.With(write).Put("/user-groups/{groupID}", h.UpdateUserGroup)
|
||||
r.With(write).Delete("/user-groups/{groupID}", h.DeleteUserGroup)
|
||||
r.With(write).Put("/user-groups/{groupID}/members", h.SetUserGroupMembers)
|
||||
|
||||
r.With(read).Get("/audit", h.ListAuditLogs)
|
||||
r.With(read).Get("/audit/export", h.ExportAuditLogs)
|
||||
@ -97,11 +105,17 @@ func (h *Handler) ListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
groupID, verr := validateGroupIDFilter(r.URL.Query().Get("group_id"))
|
||||
if verr != nil {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.ListUsers(r.Context(), params, UserFilter{
|
||||
Status: status,
|
||||
Role: role,
|
||||
Q: strings.TrimSpace(params.Q),
|
||||
Status: status,
|
||||
Role: role,
|
||||
GroupID: groupID,
|
||||
Q: strings.TrimSpace(params.Q),
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Error("list users", "error", err)
|
||||
|
||||
37
internal/api/admin/org_search_secrets.go
Normal file
37
internal/api/admin/org_search_secrets.go
Normal file
@ -0,0 +1,37 @@
|
||||
package admin
|
||||
|
||||
import "strings"
|
||||
|
||||
func buildWebSearchProviderSecretsStatus(policy map[string]any) map[string]any {
|
||||
searchSection, ok := policy["search"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
webSearch, ok := searchSection["web_search"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
providers, ok := webSearch["providers"].([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := map[string]any{}
|
||||
for _, item := range providers {
|
||||
pm, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
id, _ := pm["id"].(string)
|
||||
if strings.TrimSpace(id) == "" {
|
||||
continue
|
||||
}
|
||||
apiKey, _ := pm["api_key"].(string)
|
||||
out[id] = map[string]any{
|
||||
"configured": strings.TrimSpace(apiKey) != "",
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
@ -117,9 +117,9 @@ func defaultOrgPolicy() map[string]any {
|
||||
"enabled": false,
|
||||
"openwebui_internal_url": "",
|
||||
"public_path": "/ai",
|
||||
"embed_default_temporary": true,
|
||||
"embed_default_temporary": false,
|
||||
"default_model": "",
|
||||
"enabled_tools": []any{"mail", "drive", "contacts", "search"},
|
||||
"enabled_tools": []any{"mail", "drive", "contacts", "agenda", "search", "web_search"},
|
||||
"chat_sync_enabled": true,
|
||||
"chat_nc_path": "/.ultimail/ai/chats",
|
||||
"models": []any{},
|
||||
@ -159,7 +159,8 @@ func defaultOrgPolicy() map[string]any {
|
||||
map[string]any{"id": "contact-discovery", "name": "Découverte contacts", "description": "Enrichissement IA et signatures détectées.", "enabled": true, "version": "1.0.0"},
|
||||
map[string]any{"id": "public-share", "name": "Partage public Drive", "description": "Liens publics et partages externes.", "enabled": true, "version": "1.0.0"},
|
||||
map[string]any{"id": "office-editor", "name": "Édition OnlyOffice", "description": "Édition collaborative de documents.", "enabled": false, "version": "1.0.0"},
|
||||
map[string]any{"id": "ai-assistant", "name": "UltiAI", "description": "Assistant IA intégré (chat, tools mail/drive/contacts).", "enabled": false, "version": "1.0.0"},
|
||||
map[string]any{"id": "richtext-editor", "name": "Édition rich text TipTap", "description": "Édition rich text TipTap pour documents Word.", "enabled": true, "version": "1.0.0"},
|
||||
map[string]any{"id": "ai-assistant", "name": "UltiAI", "description": "Assistant IA intégré (chat, tools mail/drive/contacts, recherche web).", "enabled": false, "version": "1.0.0"},
|
||||
},
|
||||
"integrations": []any{
|
||||
map[string]any{"id": "authentik", "name": "Authentik", "description": "SSO, groupes et provisionnement des comptes.", "enabled": true, "configured": false},
|
||||
@ -220,6 +221,62 @@ func mergeMaps(base, patch map[string]any) map[string]any {
|
||||
return out
|
||||
}
|
||||
|
||||
func mergeOrgPlugins(defaults, stored []any) []any {
|
||||
if len(defaults) == 0 {
|
||||
return stored
|
||||
}
|
||||
storedByID := map[string]map[string]any{}
|
||||
for _, item := range stored {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
id, _ := m["id"].(string)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
storedByID[id] = m
|
||||
}
|
||||
out := make([]any, 0, len(defaults)+len(storedByID))
|
||||
seen := map[string]bool{}
|
||||
for _, item := range defaults {
|
||||
def, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
id, _ := def["id"].(string)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
seen[id] = true
|
||||
if stored, ok := storedByID[id]; ok {
|
||||
out = append(out, mergeMaps(def, stored))
|
||||
continue
|
||||
}
|
||||
out = append(out, def)
|
||||
}
|
||||
for _, item := range stored {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
id, _ := m["id"].(string)
|
||||
if id == "" || seen[id] {
|
||||
continue
|
||||
}
|
||||
out = append(out, m)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeOrgPolicy(policy map[string]any) map[string]any {
|
||||
defaults := defaultOrgPolicy()
|
||||
defPlugins, _ := defaults["plugins"].([]any)
|
||||
storedPlugins, _ := policy["plugins"].([]any)
|
||||
policy["plugins"] = mergeOrgPlugins(defPlugins, storedPlugins)
|
||||
return policy
|
||||
}
|
||||
|
||||
func mergeOrgSecrets(existing, patch map[string]any) map[string]any {
|
||||
merged := mergeMaps(existing, patch)
|
||||
secretPaths := []struct {
|
||||
@ -699,6 +756,9 @@ func buildOrgSecretsStatus(policy map[string]any, cfg *config.Config) map[string
|
||||
if llmSecrets := buildLLMProviderSecretsStatus(policy); len(llmSecrets) > 0 {
|
||||
secrets["llm_providers"] = llmSecrets
|
||||
}
|
||||
if webSearchSecrets := buildWebSearchProviderSecretsStatus(policy); len(webSearchSecrets) > 0 {
|
||||
secrets["web_search_providers"] = webSearchSecrets
|
||||
}
|
||||
return secrets
|
||||
}
|
||||
|
||||
@ -802,7 +862,7 @@ func (s *Service) GetOrgSettings(ctx context.Context) (map[string]any, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
policy := mergeMaps(defaultOrgPolicy(), stored)
|
||||
policy := normalizeOrgPolicy(mergeMaps(defaultOrgPolicy(), stored))
|
||||
masked := maskOrgPolicy(policy)
|
||||
return map[string]any{
|
||||
"policy": masked,
|
||||
|
||||
49
internal/api/admin/org_settings_plugins_test.go
Normal file
49
internal/api/admin/org_settings_plugins_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package admin
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestMergeOrgPluginsAddsMissingDefaults(t *testing.T) {
|
||||
defaults := defaultOrgPolicy()["plugins"].([]any)
|
||||
stored := []any{
|
||||
map[string]any{
|
||||
"id": "mail-automation",
|
||||
"name": "Old name",
|
||||
"description": "Old description",
|
||||
"enabled": false,
|
||||
"version": "0.9.0",
|
||||
},
|
||||
}
|
||||
|
||||
merged := mergeOrgPlugins(defaults, stored)
|
||||
if len(merged) != len(defaults) {
|
||||
t.Fatalf("len(merged) = %d, want %d", len(merged), len(defaults))
|
||||
}
|
||||
|
||||
first, ok := merged[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("first plugin is not a map")
|
||||
}
|
||||
if first["enabled"] != false {
|
||||
t.Fatalf("enabled = %v, want false", first["enabled"])
|
||||
}
|
||||
if first["version"] != "0.9.0" {
|
||||
t.Fatalf("version = %v, want stored override", first["version"])
|
||||
}
|
||||
|
||||
var hasRichtext bool
|
||||
for _, item := range merged {
|
||||
plugin, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if plugin["id"] == "richtext-editor" {
|
||||
hasRichtext = true
|
||||
if plugin["enabled"] != true {
|
||||
t.Fatalf("richtext-editor enabled = %v, want true", plugin["enabled"])
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasRichtext {
|
||||
t.Fatal("expected richtext-editor from defaults")
|
||||
}
|
||||
}
|
||||
@ -91,10 +91,6 @@ func (s *Service) ListPublicShares(ctx context.Context, params query.ListParams)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(all, func(i, j int) bool {
|
||||
return shareCreatedAt(all[i]).After(shareCreatedAt(all[j]))
|
||||
})
|
||||
|
||||
tokens := make([]string, 0, len(all))
|
||||
for _, item := range all {
|
||||
if t, _ := item["token"].(string); t != "" {
|
||||
@ -116,6 +112,8 @@ func (s *Service) ListPublicShares(ctx context.Context, params query.ListParams)
|
||||
}
|
||||
}
|
||||
|
||||
sortPublicShares(all, params.Sort)
|
||||
|
||||
total := int64(len(all))
|
||||
start := params.Offset()
|
||||
if start > len(all) {
|
||||
@ -192,3 +190,70 @@ func firstNonEmptyStr(values ...string) string {
|
||||
func ptrInt64(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func sortPublicShares(items []map[string]any, sortParam string) {
|
||||
field, desc := parseSortField(sortParam)
|
||||
if field == "" {
|
||||
field = "created_at"
|
||||
desc = true
|
||||
}
|
||||
|
||||
less := func(i, j int) bool {
|
||||
switch field {
|
||||
case "last_access_at":
|
||||
return shareLastAccessAt(items[i]).Before(shareLastAccessAt(items[j]))
|
||||
case "access_count":
|
||||
return shareAccessCount(items[i]) < shareAccessCount(items[j])
|
||||
case "path":
|
||||
return strings.ToLower(shareString(items[i], "path")) < strings.ToLower(shareString(items[j], "path"))
|
||||
case "owner_email":
|
||||
return strings.ToLower(shareString(items[i], "owner_email")) < strings.ToLower(shareString(items[j], "owner_email"))
|
||||
case "updated_at", "created_at":
|
||||
return shareCreatedAt(items[i]).Before(shareCreatedAt(items[j]))
|
||||
default:
|
||||
return shareCreatedAt(items[i]).Before(shareCreatedAt(items[j]))
|
||||
}
|
||||
}
|
||||
|
||||
sort.SliceStable(items, func(i, j int) bool {
|
||||
if desc {
|
||||
return !less(i, j)
|
||||
}
|
||||
return less(i, j)
|
||||
})
|
||||
}
|
||||
|
||||
func shareString(item map[string]any, key string) string {
|
||||
raw, _ := item[key].(string)
|
||||
return strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
func shareAccessCount(item map[string]any) int64 {
|
||||
switch v := item["access_count"].(type) {
|
||||
case int64:
|
||||
return v
|
||||
case int:
|
||||
return int64(v)
|
||||
case float64:
|
||||
return int64(v)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func shareLastAccessAt(item map[string]any) time.Time {
|
||||
raw, ok := item["last_access_at"]
|
||||
if !ok || raw == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
t, err := time.Parse(time.RFC3339, v)
|
||||
if err != nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return t
|
||||
default:
|
||||
return time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
@ -62,9 +62,10 @@ type UsersList struct {
|
||||
}
|
||||
|
||||
type UserFilter struct {
|
||||
Status string
|
||||
Role string
|
||||
Q string
|
||||
Status string
|
||||
Role string
|
||||
GroupID string
|
||||
Q string
|
||||
}
|
||||
|
||||
func (s *Service) ListUsers(ctx context.Context, params query.ListParams, filter UserFilter) (UsersList, error) {
|
||||
@ -79,7 +80,7 @@ func (s *Service) ListUsers(ctx context.Context, params query.ListParams, filter
|
||||
listSQL := `
|
||||
SELECT id, external_id, email, name, status, platform_admin, invited_at, disabled_at, created_at, updated_at
|
||||
FROM users` + whereSQL + `
|
||||
ORDER BY created_at DESC
|
||||
` + buildUserOrderBy(params.Sort) + `
|
||||
LIMIT $` + strconv.Itoa(len(args)+1) + ` OFFSET $` + strconv.Itoa(len(args)+2)
|
||||
args = append(args, params.Limit(), params.Offset())
|
||||
|
||||
@ -103,6 +104,9 @@ func (s *Service) ListUsers(ctx context.Context, params query.ListParams, filter
|
||||
if err := s.attachUsersStorage(ctx, users); err != nil {
|
||||
return UsersList{}, err
|
||||
}
|
||||
if err := s.attachUsersGroups(ctx, users); err != nil {
|
||||
return UsersList{}, err
|
||||
}
|
||||
|
||||
return UsersList{
|
||||
Users: users,
|
||||
@ -111,8 +115,8 @@ func (s *Service) ListUsers(ctx context.Context, params query.ListParams, filter
|
||||
}
|
||||
|
||||
func buildUserFilter(filter UserFilter) (string, []any) {
|
||||
clauses := make([]string, 0, 3)
|
||||
args := make([]any, 0, 3)
|
||||
clauses := make([]string, 0, 4)
|
||||
args := make([]any, 0, 4)
|
||||
if role := strings.TrimSpace(filter.Role); role != "" {
|
||||
switch permission.AccountRole(role) {
|
||||
case permission.AccountRoleAdmin:
|
||||
@ -135,6 +139,11 @@ func buildUserFilter(filter UserFilter) (string, []any) {
|
||||
idx := strconv.Itoa(len(args))
|
||||
clauses = append(clauses, "(LOWER(email) LIKE $"+idx+" OR LOWER(name) LIKE $"+idx+" OR LOWER(external_id) LIKE $"+idx+")")
|
||||
}
|
||||
if strings.TrimSpace(filter.GroupID) != "" {
|
||||
args = append(args, strings.TrimSpace(filter.GroupID))
|
||||
idx := strconv.Itoa(len(args))
|
||||
clauses = append(clauses, "EXISTS (SELECT 1 FROM user_group_members ugm WHERE ugm.user_id = users.id AND ugm.group_id = $"+idx+")")
|
||||
}
|
||||
if len(clauses) == 0 {
|
||||
return "", args
|
||||
}
|
||||
@ -188,6 +197,9 @@ func (s *Service) GetUser(ctx context.Context, userID string) (map[string]any, e
|
||||
"max_storage_bytes": photosMax,
|
||||
},
|
||||
}
|
||||
if err := s.attachUsersGroups(ctx, []map[string]any{user}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
|
||||
35
internal/api/admin/sort.go
Normal file
35
internal/api/admin/sort.go
Normal file
@ -0,0 +1,35 @@
|
||||
package admin
|
||||
|
||||
import "strings"
|
||||
|
||||
func parseSortField(sort string) (field string, desc bool) {
|
||||
sort = strings.TrimSpace(sort)
|
||||
if sort == "" {
|
||||
return "", false
|
||||
}
|
||||
if strings.HasPrefix(sort, "-") {
|
||||
return strings.TrimPrefix(sort, "-"), true
|
||||
}
|
||||
return sort, false
|
||||
}
|
||||
|
||||
func buildUserOrderBy(sort string) string {
|
||||
field, desc := parseSortField(sort)
|
||||
allowed := map[string]string{
|
||||
"created_at": "created_at",
|
||||
"updated_at": "updated_at",
|
||||
"email": "LOWER(email)",
|
||||
"name": "LOWER(name)",
|
||||
"status": "status",
|
||||
"external_id": "LOWER(external_id)",
|
||||
}
|
||||
col, ok := allowed[field]
|
||||
if !ok {
|
||||
return "ORDER BY created_at DESC"
|
||||
}
|
||||
dir := "ASC"
|
||||
if desc {
|
||||
dir = "DESC"
|
||||
}
|
||||
return "ORDER BY " + col + " " + dir
|
||||
}
|
||||
31
internal/api/admin/sort_test.go
Normal file
31
internal/api/admin/sort_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
package admin
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildUserOrderBy(t *testing.T) {
|
||||
tests := []struct {
|
||||
sort string
|
||||
want string
|
||||
}{
|
||||
{"", "ORDER BY created_at DESC"},
|
||||
{"-email", "ORDER BY LOWER(email) DESC"},
|
||||
{"name", "ORDER BY LOWER(name) ASC"},
|
||||
{"invalid", "ORDER BY created_at DESC"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := buildUserOrderBy(tt.sort); got != tt.want {
|
||||
t.Fatalf("buildUserOrderBy(%q) = %q, want %q", tt.sort, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSortField(t *testing.T) {
|
||||
field, desc := parseSortField("-created_at")
|
||||
if field != "created_at" || !desc {
|
||||
t.Fatalf("parseSortField = %q %v", field, desc)
|
||||
}
|
||||
field, desc = parseSortField("email")
|
||||
if field != "email" || desc {
|
||||
t.Fatalf("parseSortField = %q %v", field, desc)
|
||||
}
|
||||
}
|
||||
131
internal/api/admin/user_bulk.go
Normal file
131
internal/api/admin/user_bulk.go
Normal file
@ -0,0 +1,131 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
platformusers "github.com/ultisuite/ulti-backend/internal/users"
|
||||
)
|
||||
|
||||
type bulkUsersRequest struct {
|
||||
UserIDs []string `json:"user_ids"`
|
||||
Action string `json:"action"`
|
||||
Role string `json:"role"`
|
||||
GroupID string `json:"group_id"`
|
||||
}
|
||||
|
||||
type bulkUsersResult struct {
|
||||
SuccessCount int `json:"success_count"`
|
||||
Failed []bulkUserFailure `json:"failed,omitempty"`
|
||||
}
|
||||
|
||||
type bulkUserFailure struct {
|
||||
UserID string `json:"user_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (s *Service) BulkUsersAction(ctx context.Context, actorSub string, req bulkUsersRequest) (bulkUsersResult, error) {
|
||||
action := strings.ToLower(strings.TrimSpace(req.Action))
|
||||
userIDs := dedupeNonEmpty(req.UserIDs)
|
||||
if len(userIDs) == 0 {
|
||||
return bulkUsersResult{}, fmt.Errorf("no user ids")
|
||||
}
|
||||
|
||||
result := bulkUsersResult{}
|
||||
switch action {
|
||||
case "disable":
|
||||
for _, userID := range userIDs {
|
||||
if err := s.DisableUser(ctx, actorSub, userID); err != nil {
|
||||
result.Failed = append(result.Failed, bulkUserFailure{UserID: userID, Message: bulkErrorMessage(err)})
|
||||
continue
|
||||
}
|
||||
result.SuccessCount++
|
||||
}
|
||||
case "reactivate":
|
||||
for _, userID := range userIDs {
|
||||
if err := s.ReactivateUser(ctx, actorSub, userID); err != nil {
|
||||
result.Failed = append(result.Failed, bulkUserFailure{UserID: userID, Message: bulkErrorMessage(err)})
|
||||
continue
|
||||
}
|
||||
result.SuccessCount++
|
||||
}
|
||||
case "delete":
|
||||
for _, userID := range userIDs {
|
||||
if err := s.DeleteUser(ctx, actorSub, userID); err != nil {
|
||||
result.Failed = append(result.Failed, bulkUserFailure{UserID: userID, Message: bulkErrorMessage(err)})
|
||||
continue
|
||||
}
|
||||
result.SuccessCount++
|
||||
}
|
||||
case "set_role":
|
||||
role := strings.TrimSpace(req.Role)
|
||||
for _, userID := range userIDs {
|
||||
if _, err := s.SetUserRole(ctx, actorSub, userID, role); err != nil {
|
||||
result.Failed = append(result.Failed, bulkUserFailure{UserID: userID, Message: bulkErrorMessage(err)})
|
||||
continue
|
||||
}
|
||||
result.SuccessCount++
|
||||
}
|
||||
case "add_to_group":
|
||||
groupID := strings.TrimSpace(req.GroupID)
|
||||
if groupID == "" {
|
||||
return bulkUsersResult{}, fmt.Errorf("group_id required")
|
||||
}
|
||||
if err := s.AddUsersToGroup(ctx, actorSub, groupID, userIDs); err != nil {
|
||||
return bulkUsersResult{}, err
|
||||
}
|
||||
result.SuccessCount = len(userIDs)
|
||||
case "remove_from_group":
|
||||
groupID := strings.TrimSpace(req.GroupID)
|
||||
if groupID == "" {
|
||||
return bulkUsersResult{}, fmt.Errorf("group_id required")
|
||||
}
|
||||
if err := s.RemoveUsersFromGroup(ctx, actorSub, groupID, userIDs); err != nil {
|
||||
return bulkUsersResult{}, err
|
||||
}
|
||||
result.SuccessCount = len(userIDs)
|
||||
default:
|
||||
return bulkUsersResult{}, fmt.Errorf("invalid action")
|
||||
}
|
||||
|
||||
s.logAudit(ctx, actorSub, "bulk_users_action", map[string]any{
|
||||
"action": action,
|
||||
"success_count": result.SuccessCount,
|
||||
"failed_count": len(result.Failed),
|
||||
"group_id": strings.TrimSpace(req.GroupID),
|
||||
"role": strings.TrimSpace(req.Role),
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func bulkErrorMessage(err error) string {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
return "not found"
|
||||
}
|
||||
if errors.Is(err, platformusers.ErrLastPlatformAdmin) {
|
||||
return "cannot remove the last platform admin"
|
||||
}
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func dedupeNonEmpty(ids []string) []string {
|
||||
seen := make(map[string]struct{}, len(ids))
|
||||
out := make([]string, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
366
internal/api/admin/user_groups.go
Normal file
366
internal/api/admin/user_groups.go
Normal file
@ -0,0 +1,366 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||
)
|
||||
|
||||
var ErrGroupNameTaken = errors.New("group name already exists")
|
||||
|
||||
type UserGroupsList struct {
|
||||
Groups []map[string]any `json:"groups"`
|
||||
Pagination query.PaginationMeta `json:"pagination,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Service) ListUserGroups(ctx context.Context, params query.ListParams) (UserGroupsList, error) {
|
||||
q := strings.TrimSpace(params.Q)
|
||||
whereSQL := ""
|
||||
args := make([]any, 0, 3)
|
||||
if q != "" {
|
||||
pattern := "%" + strings.ToLower(q) + "%"
|
||||
args = append(args, pattern)
|
||||
whereSQL = " WHERE LOWER(name) LIKE $1 OR LOWER(description) LIKE $1"
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := s.db.QueryRow(ctx, "SELECT COUNT(*) FROM user_groups"+whereSQL, args...).Scan(&total); err != nil {
|
||||
return UserGroupsList{}, err
|
||||
}
|
||||
|
||||
listSQL := `
|
||||
SELECT g.id, g.name, g.description, g.created_at, g.updated_at,
|
||||
COALESCE(m.member_count, 0) AS member_count
|
||||
FROM user_groups g
|
||||
LEFT JOIN (
|
||||
SELECT group_id, COUNT(*) AS member_count
|
||||
FROM user_group_members
|
||||
GROUP BY group_id
|
||||
) m ON m.group_id = g.id` + whereSQL + `
|
||||
ORDER BY LOWER(g.name)
|
||||
LIMIT $` + strconv.Itoa(len(args)+1) + ` OFFSET $` + strconv.Itoa(len(args)+2)
|
||||
args = append(args, params.Limit(), params.Offset())
|
||||
|
||||
rows, err := s.db.Query(ctx, listSQL, args...)
|
||||
if err != nil {
|
||||
return UserGroupsList{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
groups := make([]map[string]any, 0)
|
||||
for rows.Next() {
|
||||
group, err := scanUserGroupRow(rows)
|
||||
if err != nil {
|
||||
return UserGroupsList{}, err
|
||||
}
|
||||
groups = append(groups, group)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return UserGroupsList{}, err
|
||||
}
|
||||
|
||||
return UserGroupsList{
|
||||
Groups: groups,
|
||||
Pagination: params.Meta(&total),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetUserGroup(ctx context.Context, groupID string) (map[string]any, error) {
|
||||
row := s.db.QueryRow(ctx, `
|
||||
SELECT g.id, g.name, g.description, g.created_at, g.updated_at,
|
||||
COALESCE(m.member_count, 0) AS member_count
|
||||
FROM user_groups g
|
||||
LEFT JOIN (
|
||||
SELECT group_id, COUNT(*) AS member_count
|
||||
FROM user_group_members
|
||||
GROUP BY group_id
|
||||
) m ON m.group_id = g.id
|
||||
WHERE g.id = $1
|
||||
`, groupID)
|
||||
group, err := scanUserGroupRow(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
type createUserGroupRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
type updateUserGroupRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
}
|
||||
|
||||
func (s *Service) CreateUserGroup(ctx context.Context, actorSub string, req createUserGroupRequest) (map[string]any, error) {
|
||||
name := strings.TrimSpace(req.Name)
|
||||
description := strings.TrimSpace(req.Description)
|
||||
var id string
|
||||
err := s.db.QueryRow(ctx, `
|
||||
INSERT INTO user_groups (name, description)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id
|
||||
`, name, description).Scan(&id)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return nil, ErrGroupNameTaken
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
s.logAudit(ctx, actorSub, "create_user_group", map[string]any{
|
||||
"group_id": id,
|
||||
"name": name,
|
||||
})
|
||||
return s.GetUserGroup(ctx, id)
|
||||
}
|
||||
|
||||
func (s *Service) UpdateUserGroup(ctx context.Context, actorSub, groupID string, req updateUserGroupRequest) (map[string]any, error) {
|
||||
result, err := s.db.Exec(ctx, `
|
||||
UPDATE user_groups
|
||||
SET name = COALESCE($2, name),
|
||||
description = COALESCE($3, description),
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`, groupID, trimStringPtr(req.Name), trimStringPtr(req.Description))
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return nil, ErrGroupNameTaken
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if result.RowsAffected() == 0 {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
s.logAudit(ctx, actorSub, "update_user_group", map[string]any{"group_id": groupID})
|
||||
return s.GetUserGroup(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *Service) DeleteUserGroup(ctx context.Context, actorSub, groupID string) error {
|
||||
result, err := s.db.Exec(ctx, `DELETE FROM user_groups WHERE id = $1`, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result.RowsAffected() == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
s.logAudit(ctx, actorSub, "delete_user_group", map[string]any{"group_id": groupID})
|
||||
return nil
|
||||
}
|
||||
|
||||
type setGroupMembersRequest struct {
|
||||
UserIDs []string `json:"user_ids"`
|
||||
}
|
||||
|
||||
func (s *Service) SetGroupMembers(ctx context.Context, actorSub, groupID string, req setGroupMembersRequest) (map[string]any, error) {
|
||||
exists, err := s.groupExists(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !exists {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
tx, err := s.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
if _, err := tx.Exec(ctx, `DELETE FROM user_group_members WHERE group_id = $1`, groupID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, userID := range req.UserIDs {
|
||||
userID = strings.TrimSpace(userID)
|
||||
if userID == "" {
|
||||
continue
|
||||
}
|
||||
ok, err := s.userExistsTx(ctx, tx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO user_group_members (group_id, user_id)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING
|
||||
`, groupID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.logAudit(ctx, actorSub, "set_user_group_members", map[string]any{
|
||||
"group_id": groupID,
|
||||
"member_count": len(req.UserIDs),
|
||||
})
|
||||
return s.GetUserGroup(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *Service) AddUsersToGroup(ctx context.Context, actorSub, groupID string, userIDs []string) error {
|
||||
exists, err := s.groupExists(ctx, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
return ErrNotFound
|
||||
}
|
||||
for _, userID := range userIDs {
|
||||
userID = strings.TrimSpace(userID)
|
||||
if userID == "" {
|
||||
continue
|
||||
}
|
||||
ok, err := s.userExists(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, err := s.db.Exec(ctx, `
|
||||
INSERT INTO user_group_members (group_id, user_id)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING
|
||||
`, groupID, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.logAudit(ctx, actorSub, "add_users_to_group", map[string]any{
|
||||
"group_id": groupID,
|
||||
"count": len(userIDs),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) RemoveUsersFromGroup(ctx context.Context, actorSub, groupID string, userIDs []string) error {
|
||||
exists, err := s.groupExists(ctx, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
return ErrNotFound
|
||||
}
|
||||
for _, userID := range userIDs {
|
||||
userID = strings.TrimSpace(userID)
|
||||
if userID == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := s.db.Exec(ctx, `
|
||||
DELETE FROM user_group_members WHERE group_id = $1 AND user_id = $2
|
||||
`, groupID, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
s.logAudit(ctx, actorSub, "remove_users_from_group", map[string]any{
|
||||
"group_id": groupID,
|
||||
"count": len(userIDs),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) groupExists(ctx context.Context, groupID string) (bool, error) {
|
||||
var exists bool
|
||||
if err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM user_groups WHERE id = $1)`, groupID).Scan(&exists); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func (s *Service) userExistsTx(ctx context.Context, tx pgx.Tx, userID string) (bool, error) {
|
||||
var exists bool
|
||||
if err := tx.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)`, userID).Scan(&exists); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
type userGroupRowScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanUserGroupRow(row userGroupRowScanner) (map[string]any, error) {
|
||||
var id, name, description string
|
||||
var memberCount int64
|
||||
var createdAt, updatedAt any
|
||||
if err := row.Scan(&id, &name, &description, &createdAt, &updatedAt, &memberCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"member_count": memberCount,
|
||||
"created_at": createdAt,
|
||||
"updated_at": updatedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) attachUsersGroups(ctx context.Context, users []map[string]any) error {
|
||||
if len(users) == 0 {
|
||||
return nil
|
||||
}
|
||||
ids := make([]string, 0, len(users))
|
||||
byID := make(map[string]map[string]any, len(users))
|
||||
for _, user := range users {
|
||||
id, _ := user["id"].(string)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
ids = append(ids, id)
|
||||
byID[id] = user
|
||||
user["groups"] = []map[string]any{}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT ugm.user_id::text, g.id, g.name
|
||||
FROM user_group_members ugm
|
||||
JOIN user_groups g ON g.id = ugm.group_id
|
||||
WHERE ugm.user_id = ANY($1::uuid[])
|
||||
ORDER BY LOWER(g.name)
|
||||
`, ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var userID, groupID, name string
|
||||
if err := rows.Scan(&userID, &groupID, &name); err != nil {
|
||||
return err
|
||||
}
|
||||
user, ok := byID[userID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
groups, _ := user["groups"].([]map[string]any)
|
||||
user["groups"] = append(groups, map[string]any{
|
||||
"id": groupID,
|
||||
"name": name,
|
||||
})
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func isUniqueViolation(err error) bool {
|
||||
var pgErr *pgconn.PgError
|
||||
return errors.As(err, &pgErr) && pgErr.Code == "23505"
|
||||
}
|
||||
187
internal/api/admin/user_groups_handlers.go
Normal file
187
internal/api/admin/user_groups_handlers.go
Normal file
@ -0,0 +1,187 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||
)
|
||||
|
||||
func (h *Handler) ListUserGroups(w http.ResponseWriter, r *http.Request) {
|
||||
params, err := query.ParseListRequest(r)
|
||||
if err != nil {
|
||||
apivalidate.WriteQueryError(w, r, err)
|
||||
return
|
||||
}
|
||||
result, err := h.svc.ListUserGroups(r.Context(), params)
|
||||
if err != nil {
|
||||
h.logger.Error("list user groups", "error", err)
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *Handler) GetUserGroup(w http.ResponseWriter, r *http.Request) {
|
||||
groupID := chi.URLParam(r, "groupID")
|
||||
if verr := validateGroupID(groupID); verr != nil {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
group, err := h.svc.GetUserGroup(r.Context(), groupID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
apivalidate.WriteNotFound(w, r, "not found")
|
||||
return
|
||||
}
|
||||
h.logger.Error("get user group", "error", err)
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, group)
|
||||
}
|
||||
|
||||
func (h *Handler) CreateUserGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
var req createUserGroupRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if verr := validateCreateUserGroup(&req); verr != nil {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
group, err := h.svc.CreateUserGroup(r.Context(), claims.Sub, req)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrGroupNameTaken) {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(
|
||||
apivalidate.FieldDetail{Field: "name", Message: "already exists"},
|
||||
))
|
||||
return
|
||||
}
|
||||
h.logger.Error("create user group", "error", err)
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusCreated, group)
|
||||
}
|
||||
|
||||
func (h *Handler) UpdateUserGroup(w http.ResponseWriter, r *http.Request) {
|
||||
groupID := chi.URLParam(r, "groupID")
|
||||
if verr := validateGroupID(groupID); verr != nil {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
var req updateUserGroupRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if verr := validateUpdateUserGroup(&req); verr != nil {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
group, err := h.svc.UpdateUserGroup(r.Context(), claims.Sub, groupID, req)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
apivalidate.WriteNotFound(w, r, "not found")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, ErrGroupNameTaken) {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(
|
||||
apivalidate.FieldDetail{Field: "name", Message: "already exists"},
|
||||
))
|
||||
return
|
||||
}
|
||||
h.logger.Error("update user group", "error", err)
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, group)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteUserGroup(w http.ResponseWriter, r *http.Request) {
|
||||
groupID := chi.URLParam(r, "groupID")
|
||||
if verr := validateGroupID(groupID); verr != nil {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
if err := h.svc.DeleteUserGroup(r.Context(), claims.Sub, groupID); err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
apivalidate.WriteNotFound(w, r, "not found")
|
||||
return
|
||||
}
|
||||
h.logger.Error("delete user group", "error", err)
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *Handler) SetUserGroupMembers(w http.ResponseWriter, r *http.Request) {
|
||||
groupID := chi.URLParam(r, "groupID")
|
||||
if verr := validateGroupID(groupID); verr != nil {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
var req setGroupMembersRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
group, err := h.svc.SetGroupMembers(r.Context(), claims.Sub, groupID, req)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNotFound) {
|
||||
apivalidate.WriteNotFound(w, r, "not found")
|
||||
return
|
||||
}
|
||||
h.logger.Error("set user group members", "error", err)
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, group)
|
||||
}
|
||||
|
||||
func (h *Handler) BulkUsersAction(w http.ResponseWriter, r *http.Request) {
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
var req bulkUsersRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if verr := validateBulkUsers(&req); verr != nil {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
result, err := h.svc.BulkUsersAction(r.Context(), claims.Sub, req)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "invalid action") {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(
|
||||
apivalidate.FieldDetail{Field: "action", Message: "invalid"},
|
||||
))
|
||||
return
|
||||
}
|
||||
if strings.Contains(err.Error(), "group_id required") {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(
|
||||
apivalidate.FieldDetail{Field: "group_id", Message: "required"},
|
||||
))
|
||||
return
|
||||
}
|
||||
if strings.Contains(err.Error(), "no user ids") {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(
|
||||
apivalidate.FieldDetail{Field: "user_ids", Message: "required"},
|
||||
))
|
||||
return
|
||||
}
|
||||
h.logger.Error("bulk users action", "error", err)
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, result)
|
||||
}
|
||||
@ -155,3 +155,68 @@ func validateAccountRoleFilter(raw string) (string, *apivalidate.ValidationError
|
||||
}
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func validateGroupID(groupID string) *apivalidate.ValidationError {
|
||||
if strings.TrimSpace(groupID) == "" {
|
||||
return apivalidate.NewValidationError(apivalidate.FieldDetail{Field: "groupID", Message: "required"})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateGroupIDFilter(raw string) (string, *apivalidate.ValidationError) {
|
||||
groupID := strings.TrimSpace(raw)
|
||||
if groupID == "" {
|
||||
return "", nil
|
||||
}
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
func validateCreateUserGroup(req *createUserGroupRequest) *apivalidate.ValidationError {
|
||||
if strings.TrimSpace(req.Name) == "" {
|
||||
return apivalidate.NewValidationError(apivalidate.FieldDetail{Field: "name", Message: "required"})
|
||||
}
|
||||
if len(strings.TrimSpace(req.Name)) > 120 {
|
||||
return apivalidate.NewValidationError(apivalidate.FieldDetail{Field: "name", Message: "must be at most 120 characters"})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateUpdateUserGroup(req *updateUserGroupRequest) *apivalidate.ValidationError {
|
||||
if req.Name == nil && req.Description == nil {
|
||||
return apivalidate.NewValidationError(apivalidate.FieldDetail{Field: "group", Message: "at least one field is required"})
|
||||
}
|
||||
if req.Name != nil {
|
||||
if strings.TrimSpace(*req.Name) == "" {
|
||||
return apivalidate.NewValidationError(apivalidate.FieldDetail{Field: "name", Message: "required"})
|
||||
}
|
||||
if len(strings.TrimSpace(*req.Name)) > 120 {
|
||||
return apivalidate.NewValidationError(apivalidate.FieldDetail{Field: "name", Message: "must be at most 120 characters"})
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateBulkUsers(req *bulkUsersRequest) *apivalidate.ValidationError {
|
||||
if len(req.UserIDs) == 0 {
|
||||
return apivalidate.NewValidationError(apivalidate.FieldDetail{Field: "user_ids", Message: "required"})
|
||||
}
|
||||
action := strings.ToLower(strings.TrimSpace(req.Action))
|
||||
switch action {
|
||||
case "disable", "reactivate", "delete", "add_to_group", "remove_from_group":
|
||||
if action == "add_to_group" || action == "remove_from_group" {
|
||||
if strings.TrimSpace(req.GroupID) == "" {
|
||||
return apivalidate.NewValidationError(apivalidate.FieldDetail{Field: "group_id", Message: "required"})
|
||||
}
|
||||
}
|
||||
case "set_role":
|
||||
if _, ok := permission.ParseAccountRole(req.Role); !ok {
|
||||
return apivalidate.NewValidationError(apivalidate.FieldDetail{
|
||||
Field: "role",
|
||||
Message: "must be one of: admin,user,guest,suspended",
|
||||
})
|
||||
}
|
||||
default:
|
||||
return apivalidate.NewValidationError(apivalidate.FieldDetail{Field: "action", Message: "invalid"})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -20,6 +21,7 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
"github.com/ultisuite/ulti-backend/internal/permission"
|
||||
platformusers "github.com/ultisuite/ulti-backend/internal/users"
|
||||
)
|
||||
|
||||
const sessionAccessCookie = "ulti_access_token"
|
||||
@ -49,6 +51,9 @@ func (h *Handler) Routes(authMiddleware func(http.Handler) http.Handler) chi.Rou
|
||||
r.Get("/config", h.GetConfig)
|
||||
r.Get("/embed-auth", h.EmbedAuth)
|
||||
r.Post("/embed-signin", h.EmbedSignin)
|
||||
r.Get("/mcp/health", h.MCPHealth)
|
||||
r.Handle("/mcp", http.HandlerFunc(h.MCPProxy))
|
||||
r.Handle("/mcp/*", http.HandlerFunc(h.MCPProxy))
|
||||
// OpenWebUI gateway (Bearer AI_GATEWAY_API_KEY) or user JWT — not behind Auth middleware
|
||||
r.Get("/models", h.ListModels)
|
||||
r.Post("/chat/completions", h.ChatCompletions)
|
||||
@ -82,16 +87,18 @@ func (h *Handler) GetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
"enabled": entry.Enabled,
|
||||
})
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
payload := map[string]any{
|
||||
"enabled": enabled,
|
||||
"public_path": publicPath,
|
||||
"embed_default_temporary": policy.EmbedDefaultTemporary,
|
||||
"default_model": policy.DefaultModel,
|
||||
"default_model": ai.ResolveDefaultModel(r.Context(), h.db, policy),
|
||||
"enabled_tools": policy.EnabledTools,
|
||||
"chat_sync_enabled": policy.ChatSyncEnabled,
|
||||
"models": models,
|
||||
"restrict_models": len(policy.Models) > 0,
|
||||
})
|
||||
}
|
||||
h.writeMCPConfigFields(r.Context(), payload)
|
||||
apiresponse.WriteJSON(w, http.StatusOK, payload)
|
||||
}
|
||||
|
||||
func (h *Handler) EmbedAuth(w http.ResponseWriter, r *http.Request) {
|
||||
@ -151,6 +158,16 @@ func (h *Handler) EmbedSignin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||
avatarURL, avatarErr := platformusers.GetAvatarURL(r.Context(), h.db, claims.Sub)
|
||||
if avatarErr == nil && avatarURL == "" {
|
||||
avatarURL, _ = platformusers.ImportAvatarFromAuthentik(r.Context(), h.db, h.cfg, claims.Sub)
|
||||
}
|
||||
if syncErr := ai.SyncOpenWebUIProfile(r.Context(), h.cfg, claims, avatarURL); syncErr != nil {
|
||||
slog.Warn("sync openwebui profile on signin", "sub", claims.Sub, "error", syncErr)
|
||||
}
|
||||
}
|
||||
|
||||
for k, vals := range resp.Header {
|
||||
for _, v := range vals {
|
||||
w.Header().Add(k, v)
|
||||
@ -227,11 +244,16 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid body", nil)
|
||||
return
|
||||
}
|
||||
subject := externalUserID
|
||||
quotaSubject := externalUserID
|
||||
if useOrg {
|
||||
subject = "openwebui-gateway"
|
||||
quotaSubject = ""
|
||||
if email := strings.TrimSpace(r.Header.Get("X-OpenWebUI-User-Email")); email != "" {
|
||||
if extID, err := ai.ResolveExternalIDByEmail(r.Context(), h.db, email); err == nil && extID != "" {
|
||||
quotaSubject = extID
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := h.gateway.ProxyChatCompletions(r.Context(), subject, body, w); err != nil {
|
||||
if err := h.gateway.ProxyChatCompletions(r.Context(), quotaSubject, useOrg, body, w); err != nil {
|
||||
if errors.Is(err, ai.ErrQuotaExceeded) {
|
||||
apiresponse.WriteError(w, r, http.StatusTooManyRequests, apiresponse.CodeRateLimited, err.Error(), nil)
|
||||
return
|
||||
@ -281,7 +303,12 @@ func (h *Handler) CreateSession(w http.ResponseWriter, r *http.Request) {
|
||||
if h.cfg != nil && strings.TrimSpace(h.cfg.AIAssistantPublicPath) != "" {
|
||||
publicPath = h.cfg.AIAssistantPublicPath
|
||||
}
|
||||
temporary := req.Temporary || policy.EmbedDefaultTemporary
|
||||
temporary := policy.EmbedDefaultTemporary
|
||||
if strings.EqualFold(strings.TrimSpace(req.App), "standalone") {
|
||||
temporary = false
|
||||
} else if req.Temporary {
|
||||
temporary = true
|
||||
}
|
||||
q := url.Values{}
|
||||
if temporary {
|
||||
q.Set("temporary-chat", "true")
|
||||
@ -293,12 +320,15 @@ func (h *Handler) CreateSession(w http.ResponseWriter, r *http.Request) {
|
||||
if enc := q.Encode(); enc != "" {
|
||||
embedURL += "?" + enc
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, ai.SessionResponse{
|
||||
resp := ai.SessionResponse{
|
||||
SessionID: created.ID,
|
||||
EmbedURL: embedURL,
|
||||
TokenSecret: created.TokenSecret,
|
||||
Temporary: temporary,
|
||||
})
|
||||
MCPURL: h.publicMCPPath(),
|
||||
EnabledTools: h.loadEnabledTools(r.Context()),
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) SyncChat(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
184
internal/api/ai/mcp_proxy.go
Normal file
184
internal/api/ai/mcp_proxy.go
Normal file
@ -0,0 +1,184 @@
|
||||
package aiapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/ai"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/apitokens"
|
||||
"github.com/ultisuite/ulti-backend/internal/users"
|
||||
)
|
||||
|
||||
func (h *Handler) MCPProxy(w http.ResponseWriter, r *http.Request) {
|
||||
if h.cfg == nil || strings.TrimSpace(h.cfg.UltimailMCPURL) == "" {
|
||||
apiresponse.WriteError(w, r, http.StatusServiceUnavailable, apiresponse.CodeInternal, "mcp not configured", nil)
|
||||
return
|
||||
}
|
||||
|
||||
token, enabledTools, err := h.resolveMCPToken(r)
|
||||
if err != nil {
|
||||
apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
target, err := url.Parse(strings.TrimRight(strings.TrimSpace(h.cfg.UltimailMCPURL), "/"))
|
||||
if err != nil {
|
||||
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
||||
origDirector := proxy.Director
|
||||
upstreamPath := mapMCPUpstreamPath(r.URL.Path)
|
||||
proxy.Director = func(req *http.Request) {
|
||||
origDirector(req)
|
||||
req.Host = target.Host
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
req.URL.Path = upstreamPath
|
||||
req.Header.Set("X-Ulti-Token", token)
|
||||
if len(enabledTools) > 0 {
|
||||
req.Header.Set("X-Ulti-Enabled-Tools", strings.Join(enabledTools, ","))
|
||||
}
|
||||
}
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
resp.Header.Del("Access-Control-Allow-Origin")
|
||||
return nil
|
||||
}
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
apiresponse.WriteError(w, r, http.StatusBadGateway, apiresponse.CodeInternal, err.Error(), nil)
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (h *Handler) resolveMCPToken(r *http.Request) (token string, enabledTools []string, err error) {
|
||||
ctx := r.Context()
|
||||
enabledTools = h.loadEnabledTools(ctx)
|
||||
|
||||
if headerToken := strings.TrimSpace(r.Header.Get("X-Ulti-Token")); headerToken != "" {
|
||||
if _, authErr := apitokens.Authenticate(ctx, h.db, headerToken); authErr == nil {
|
||||
return headerToken, enabledTools, nil
|
||||
}
|
||||
}
|
||||
|
||||
if bearer := bearerToken(r); strings.HasPrefix(bearer, apitokens.TokenPrefix()) {
|
||||
if _, authErr := apitokens.Authenticate(ctx, h.db, bearer); authErr == nil {
|
||||
return bearer, enabledTools, nil
|
||||
}
|
||||
}
|
||||
|
||||
if bearer := bearerToken(r); h.cfg != nil && h.cfg.AIGatewayAPIKey != "" && bearer == h.cfg.AIGatewayAPIKey {
|
||||
email := openWebUIUserEmail(r)
|
||||
if email == "" {
|
||||
return "", nil, fmt.Errorf("missing openwebui user email")
|
||||
}
|
||||
created, createErr := h.createMCPSessionForEmail(ctx, email, apitokens.ChatSessionStandalone)
|
||||
if createErr != nil {
|
||||
return "", nil, createErr
|
||||
}
|
||||
return created.TokenSecret, enabledTools, nil
|
||||
}
|
||||
|
||||
if claims, ok := h.resolveClaims(r); ok && strings.TrimSpace(claims.Sub) != "" {
|
||||
created, createErr := apitokens.CreateChatSession(ctx, h.db, claims.Sub, claims.Email, apitokens.ChatSessionInput{
|
||||
Preset: apitokens.ChatSessionStandalone,
|
||||
})
|
||||
if createErr != nil {
|
||||
return "", nil, createErr
|
||||
}
|
||||
return created.TokenSecret, enabledTools, nil
|
||||
}
|
||||
|
||||
return "", nil, fmt.Errorf("unauthorized")
|
||||
}
|
||||
|
||||
func (h *Handler) createMCPSessionForEmail(ctx context.Context, email string, preset apitokens.ChatSessionPreset) (apitokens.CreatedToken, error) {
|
||||
externalID, storedEmail, err := users.LookupIdentityByEmail(ctx, h.db, email)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return apitokens.CreatedToken{}, fmt.Errorf("user not found")
|
||||
}
|
||||
return apitokens.CreatedToken{}, err
|
||||
}
|
||||
return apitokens.CreateChatSession(ctx, h.db, externalID, storedEmail, apitokens.ChatSessionInput{
|
||||
Preset: preset,
|
||||
})
|
||||
}
|
||||
|
||||
func openWebUIUserEmail(r *http.Request) string {
|
||||
for _, key := range []string{
|
||||
"X-OpenWebUI-User-Email",
|
||||
"X-Ulti-User-Email",
|
||||
} {
|
||||
if email := strings.TrimSpace(r.Header.Get(key)); email != "" {
|
||||
return email
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *Handler) loadEnabledTools(ctx context.Context) []string {
|
||||
deployEnabled := h.cfg != nil && h.cfg.AIAssistantEnabled
|
||||
policy, _ := ai.LoadAssistantPolicy(ctx, h.db)
|
||||
if !deployEnabled && !policy.Enabled {
|
||||
return policy.EnabledTools
|
||||
}
|
||||
if len(policy.EnabledTools) == 0 {
|
||||
return []string{"mail", "drive", "contacts", "agenda", "search", "web_search", "docs"}
|
||||
}
|
||||
return policy.EnabledTools
|
||||
}
|
||||
|
||||
func mapMCPUpstreamPath(path string) string {
|
||||
for _, prefix := range []string{"/api/v1/ai/mcp", "/mcp"} {
|
||||
if !strings.HasPrefix(path, prefix) {
|
||||
continue
|
||||
}
|
||||
suffix := strings.TrimPrefix(path, prefix)
|
||||
if suffix == "" || suffix == "/" {
|
||||
return "/mcp"
|
||||
}
|
||||
if suffix == "/messages" || suffix == "/sse" {
|
||||
return suffix
|
||||
}
|
||||
return "/mcp" + suffix
|
||||
}
|
||||
return "/mcp"
|
||||
}
|
||||
|
||||
func (h *Handler) publicMCPPath() string {
|
||||
return "/api/v1/ai/mcp"
|
||||
}
|
||||
|
||||
func (h *Handler) writeMCPConfigFields(ctx context.Context, out map[string]any) {
|
||||
out["mcp_url"] = h.publicMCPPath()
|
||||
out["enabled_tools"] = h.loadEnabledTools(ctx)
|
||||
}
|
||||
|
||||
// MCPHealth proxies ultimail-mcp /health without auth (for compose healthchecks).
|
||||
func (h *Handler) MCPHealth(w http.ResponseWriter, r *http.Request) {
|
||||
if h.cfg == nil || strings.TrimSpace(h.cfg.UltimailMCPURL) == "" {
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": false, "reason": "mcp not configured"})
|
||||
return
|
||||
}
|
||||
target := strings.TrimRight(strings.TrimSpace(h.cfg.UltimailMCPURL), "/") + "/health"
|
||||
resp, err := http.Get(target)
|
||||
if err != nil {
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": false, "reason": err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, _ = w.Write(body)
|
||||
}
|
||||
@ -45,6 +45,10 @@ func NewHandler(nc *nextcloud.Client, db *pgxpool.Pool) *Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Discovery() *discovery.Service {
|
||||
return h.discovery
|
||||
}
|
||||
|
||||
func (h *Handler) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
read := middleware.RequirePermission(permission.ResourceContacts, permission.LevelRead)
|
||||
|
||||
@ -54,6 +54,32 @@ func (s *Service) orgPolicyLoader() *orgpolicy.Loader {
|
||||
return orgpolicy.NewLoader(s.db, nil)
|
||||
}
|
||||
|
||||
func (s *Service) ListOrgMountsAdmin(ctx context.Context) ([]MountView, error) {
|
||||
store := s.ensureStore()
|
||||
if store == nil {
|
||||
return nil, fmt.Errorf("store not configured")
|
||||
}
|
||||
mounts, err := store.ListOrgMounts(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]MountView, 0, len(mounts))
|
||||
for _, m := range mounts {
|
||||
out = append(out, mapMountView(m))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Service) CreateOrgWebDAVMount(ctx context.Context, orgSlug, displayName string, webdav nextcloud.WebDAVMountConfig) (MountView, error) {
|
||||
return s.CreateMount(ctx, "", "", CreateMountParams{
|
||||
Scope: "org",
|
||||
OrgSlug: orgSlug,
|
||||
DisplayName: displayName,
|
||||
BackendType: "webdav",
|
||||
WebDAV: &webdav,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) ListMountsForUser(ctx context.Context, platformUserID, ncUserID string, orgSlugs []string) ([]MountView, error) {
|
||||
store := s.ensureStore()
|
||||
if store == nil {
|
||||
@ -203,6 +229,21 @@ func (s *Service) CreateMount(ctx context.Context, platformUserID, ncUserID stri
|
||||
return mapMountView(row), nil
|
||||
}
|
||||
|
||||
func (s *Service) GetMountAdmin(ctx context.Context, mountID string) (MountView, error) {
|
||||
store := s.ensureStore()
|
||||
if store == nil {
|
||||
return MountView{}, fmt.Errorf("store not configured")
|
||||
}
|
||||
mount, err := store.GetMount(ctx, mountID)
|
||||
if err != nil {
|
||||
if errors.Is(err, drivestore.ErrMountNotFound) {
|
||||
return MountView{}, ErrNotFound
|
||||
}
|
||||
return MountView{}, err
|
||||
}
|
||||
return mapMountView(mount), nil
|
||||
}
|
||||
|
||||
func (s *Service) DeleteMount(ctx context.Context, mountID string) error {
|
||||
store := s.ensureStore()
|
||||
if store == nil {
|
||||
|
||||
@ -16,6 +16,50 @@ var systemFolderSlugs = map[string]string{
|
||||
"spam": "spam",
|
||||
}
|
||||
|
||||
// HiddenMailboxExclusion decides whether spam/trash should be filtered out of list/search
|
||||
// results. Explicit folder or label targeting spam/trash disables the matching exclusion.
|
||||
func HiddenMailboxExclusion(folder, label string, includeSpam, includeTrash bool) (excludeSpam, excludeTrash bool) {
|
||||
folder = strings.ToLower(strings.TrimSpace(folder))
|
||||
label = strings.ToLower(strings.TrimSpace(label))
|
||||
excludeSpam = !includeSpam && folder != "spam" && label != "spam"
|
||||
excludeTrash = !includeTrash && folder != "trash" && label != "trash"
|
||||
return excludeSpam, excludeTrash
|
||||
}
|
||||
|
||||
func AppendHiddenMailboxExclusion(base string, args []any, argIdx int, excludeSpam, excludeTrash bool) (string, []any, int) {
|
||||
if !excludeSpam && !excludeTrash {
|
||||
return base, args, argIdx
|
||||
}
|
||||
|
||||
folderTypes := make([]string, 0, 2)
|
||||
if excludeSpam {
|
||||
folderTypes = append(folderTypes, "spam")
|
||||
}
|
||||
if excludeTrash {
|
||||
folderTypes = append(folderTypes, "trash")
|
||||
}
|
||||
base += fmt.Sprintf(`
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM mail_folders f
|
||||
WHERE f.id = m.folder_id AND f.account_id = m.account_id
|
||||
AND f.folder_type = ANY($%d)
|
||||
)`, argIdx)
|
||||
args = append(args, folderTypes)
|
||||
argIdx++
|
||||
|
||||
if excludeSpam {
|
||||
base += fmt.Sprintf(" AND NOT ($%d = ANY(m.labels))", argIdx)
|
||||
args = append(args, "spam")
|
||||
argIdx++
|
||||
}
|
||||
if excludeTrash {
|
||||
base += fmt.Sprintf(" AND NOT ($%d = ANY(m.labels))", argIdx)
|
||||
args = append(args, "trash")
|
||||
argIdx++
|
||||
}
|
||||
return base, args, argIdx
|
||||
}
|
||||
|
||||
// folderFilterClause builds a SQL fragment that resolves a folder query param to
|
||||
// mail_folders rows. System slugs (e.g. "inbox") match folder_type; UUIDs match
|
||||
// folder id; everything else matches display name case-insensitively.
|
||||
|
||||
@ -42,6 +42,33 @@ func TestFolderFilterClause(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHiddenMailboxExclusion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
folder string
|
||||
label string
|
||||
includeSpam bool
|
||||
includeTrash bool
|
||||
wantSpam bool
|
||||
wantTrash bool
|
||||
}{
|
||||
{name: "default", wantSpam: true, wantTrash: true},
|
||||
{name: "include spam", includeSpam: true, wantTrash: true},
|
||||
{name: "folder spam", folder: "spam", wantSpam: false, wantTrash: true},
|
||||
{name: "label spam", label: "spam", wantSpam: false, wantTrash: true},
|
||||
{name: "folder trash", folder: "trash", wantSpam: true, wantTrash: false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gotSpam, gotTrash := HiddenMailboxExclusion(tc.folder, tc.label, tc.includeSpam, tc.includeTrash)
|
||||
if gotSpam != tc.wantSpam || gotTrash != tc.wantTrash {
|
||||
t.Fatalf("HiddenMailboxExclusion() = (%v,%v), want (%v,%v)", gotSpam, gotTrash, tc.wantSpam, tc.wantTrash)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func stringsContains(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(sub) == 0 || indexOf(s, sub) >= 0)
|
||||
}
|
||||
|
||||
@ -312,8 +312,10 @@ func (h *Handler) ListMessages(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
filter := MessageListFilter{
|
||||
Folder: r.URL.Query().Get("folder"),
|
||||
AccountID: r.URL.Query().Get("account_id"),
|
||||
Folder: r.URL.Query().Get("folder"),
|
||||
AccountID: r.URL.Query().Get("account_id"),
|
||||
IncludeSpam: parseOptionalBool(r.URL.Query().Get("include_spam")),
|
||||
IncludeTrash: parseOptionalBool(r.URL.Query().Get("include_trash")),
|
||||
}
|
||||
h.applyMailListScope(&filter, r)
|
||||
|
||||
|
||||
@ -39,10 +39,12 @@ func (h *Handler) SearchMessages(w http.ResponseWriter, r *http.Request) {
|
||||
func parseMessageSearchFilter(r *http.Request) (MessageSearchFilter, *apivalidate.ValidationError) {
|
||||
q := r.URL.Query()
|
||||
filter := MessageSearchFilter{
|
||||
Query: q.Get("q"),
|
||||
Sender: parseSearchSender(q),
|
||||
Label: q.Get("label"),
|
||||
AccountID: q.Get("account_id"),
|
||||
Query: q.Get("q"),
|
||||
Sender: parseSearchSender(q),
|
||||
Label: q.Get("label"),
|
||||
AccountID: q.Get("account_id"),
|
||||
IncludeSpam: parseOptionalBool(q.Get("include_spam")),
|
||||
IncludeTrash: parseOptionalBool(q.Get("include_trash")),
|
||||
}
|
||||
|
||||
if raw := q.Get("date_from"); raw != "" {
|
||||
@ -122,3 +124,12 @@ func parseSearchSender(q url.Values) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseOptionalBool(raw string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,6 +20,8 @@ type MessageSearchFilter struct {
|
||||
Label string
|
||||
AccountID string
|
||||
ScopedAccountIDs []string
|
||||
IncludeSpam bool
|
||||
IncludeTrash bool
|
||||
}
|
||||
|
||||
type MessageSearchResult struct {
|
||||
@ -68,6 +70,8 @@ func (s *Service) SearchMessages(ctx context.Context, externalID string, filter
|
||||
args = append(args, tsQuery)
|
||||
argIdx++
|
||||
}
|
||||
excludeSpam, excludeTrash := HiddenMailboxExclusion("", filter.Label, filter.IncludeSpam, filter.IncludeTrash)
|
||||
base, args, argIdx = AppendHiddenMailboxExclusion(base, args, argIdx, excludeSpam, excludeTrash)
|
||||
|
||||
var total int64
|
||||
if err := s.db.QueryRow(ctx, "SELECT COUNT(*) "+base, args...).Scan(&total); err != nil {
|
||||
|
||||
@ -187,6 +187,8 @@ type MessageListFilter struct {
|
||||
Folder string
|
||||
AccountID string
|
||||
ScopedAccountIDs []string
|
||||
IncludeSpam bool
|
||||
IncludeTrash bool
|
||||
}
|
||||
|
||||
type MessagesList struct {
|
||||
@ -210,6 +212,8 @@ func (s *Service) ListMessages(ctx context.Context, externalID string, filter Me
|
||||
args = append(args, arg)
|
||||
argIdx++
|
||||
}
|
||||
excludeSpam, excludeTrash := HiddenMailboxExclusion(filter.Folder, "", filter.IncludeSpam, filter.IncludeTrash)
|
||||
baseQuery, args, argIdx = AppendHiddenMailboxExclusion(baseQuery, args, argIdx, excludeSpam, excludeTrash)
|
||||
|
||||
var total int64
|
||||
countQuery := "SELECT COUNT(*) " + baseQuery
|
||||
|
||||
@ -18,6 +18,12 @@ func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Flush() {
|
||||
if f, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func Logging(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
@ -102,7 +102,7 @@ func ParseList(values url.Values) (ListParams, error) {
|
||||
params.Page = page
|
||||
}
|
||||
|
||||
pageSize, sizeErr := parsePageSize(values.Get("page_size"))
|
||||
pageSize, sizeErr := parsePageSize(firstNonEmpty(values.Get("page_size"), values.Get("limit")))
|
||||
if sizeErr != nil {
|
||||
details = append(details, FieldDetail{Field: "page_size", Message: sizeErr.Error()})
|
||||
} else {
|
||||
@ -152,6 +152,15 @@ func ParseList(values url.Values) (ListParams, error) {
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ParseListRequest parses list query parameters from an HTTP request.
|
||||
func ParseListRequest(r *http.Request) (ListParams, error) {
|
||||
if r == nil {
|
||||
|
||||
@ -183,6 +183,24 @@ func TestMeta(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseList_limitAlias(t *testing.T) {
|
||||
params, err := ParseList(url.Values{"limit": {"5"}})
|
||||
if err != nil {
|
||||
t.Fatalf("ParseList() error = %v", err)
|
||||
}
|
||||
if params.PageSize != 5 {
|
||||
t.Fatalf("PageSize = %d, want 5", params.PageSize)
|
||||
}
|
||||
|
||||
params, err = ParseList(url.Values{"page_size": {"10"}, "limit": {"5"}})
|
||||
if err != nil {
|
||||
t.Fatalf("ParseList() error = %v", err)
|
||||
}
|
||||
if params.PageSize != 10 {
|
||||
t.Fatalf("PageSize = %d, want 10 when page_size takes precedence", params.PageSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseListRequest(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "/items?page=2&page_size=10&q=test", nil)
|
||||
if err != nil {
|
||||
|
||||
@ -13,6 +13,7 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
"github.com/ultisuite/ulti-backend/internal/permission"
|
||||
platformusers "github.com/ultisuite/ulti-backend/internal/users"
|
||||
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
|
||||
@ -20,13 +21,15 @@ import (
|
||||
|
||||
type Handler struct {
|
||||
db *pgxpool.Pool
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
orgPolicy *orgpolicy.Loader
|
||||
}
|
||||
|
||||
func NewHandler(db *pgxpool.Pool) *Handler {
|
||||
func NewHandler(db *pgxpool.Pool, cfg *config.Config) *Handler {
|
||||
return &Handler{
|
||||
db: db,
|
||||
cfg: cfg,
|
||||
orgPolicy: orgpolicy.NewLoader(db, nil),
|
||||
logger: slog.Default().With("component", "users-api"),
|
||||
}
|
||||
@ -75,6 +78,13 @@ func (h *Handler) Me(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
h.logger.Warn("read user avatar", "error", err)
|
||||
}
|
||||
if avatarURL == "" {
|
||||
if imported, importErr := platformusers.ImportAvatarFromAuthentik(r.Context(), h.db, h.cfg, claims.Sub); importErr != nil {
|
||||
h.logger.Warn("import avatar from authentik", "error", importErr)
|
||||
} else if imported != "" {
|
||||
avatarURL = imported
|
||||
}
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"sub": claims.Sub,
|
||||
@ -124,6 +134,8 @@ func (h *Handler) PutAvatar(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
platformusers.PropagateAvatarOutbound(r.Context(), h.db, h.cfg, claims, body.AvatarURL)
|
||||
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"avatar_url": body.AvatarURL,
|
||||
})
|
||||
@ -146,5 +158,7 @@ func (h *Handler) DeleteAvatar(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
platformusers.PropagateAvatarOutbound(r.Context(), h.db, h.cfg, claims, "")
|
||||
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
@ -51,18 +51,24 @@ func chatSessionGrants(in ChatSessionInput) ([]PermissionGrant, MailScope, Drive
|
||||
return []PermissionGrant{
|
||||
{Resource: "mail.messages", Read: true},
|
||||
{Resource: "mail.search", Read: true},
|
||||
{Resource: "mail.send", Write: true},
|
||||
{Resource: "mail.labels", Read: true, Write: true},
|
||||
{Resource: "contacts.read", Read: true},
|
||||
{Resource: "contacts.search", Read: true},
|
||||
{Resource: "automation.chat", Read: true},
|
||||
}, mailScope, driveScope
|
||||
case ChatSessionDrive:
|
||||
return []PermissionGrant{
|
||||
{Resource: "drive.files", Read: true, Write: in.AllowWrite},
|
||||
{Resource: "drive.download", Read: true},
|
||||
{Resource: "automation.chat", Read: true},
|
||||
}, mailScope, driveScope
|
||||
case ChatSessionContacts:
|
||||
return []PermissionGrant{
|
||||
{Resource: "contacts.read", Read: true},
|
||||
{Resource: "contacts.search", Read: true},
|
||||
{Resource: "contacts.write", Write: true},
|
||||
{Resource: "contacts.delete", Write: true},
|
||||
{Resource: "mail.search", Read: true},
|
||||
{Resource: "automation.chat", Read: true},
|
||||
}, mailScope, driveScope
|
||||
@ -79,8 +85,17 @@ func chatSessionGrants(in ChatSessionInput) ([]PermissionGrant, MailScope, Drive
|
||||
{Resource: "mail.send", Write: true},
|
||||
{Resource: "mail.labels", Read: true, Write: true},
|
||||
{Resource: "drive.files", Read: true, Write: true},
|
||||
{Resource: "drive.download", Read: true},
|
||||
{Resource: "contacts.read", Read: true},
|
||||
{Resource: "contacts.search", Read: true},
|
||||
{Resource: "contacts.write", Write: true},
|
||||
{Resource: "contacts.delete", Write: true},
|
||||
{Resource: "agenda.calendars", Read: true, Write: true},
|
||||
{Resource: "agenda.events", Read: true},
|
||||
{Resource: "agenda.events.write", Write: true},
|
||||
{Resource: "agenda.events.delete", Write: true},
|
||||
{Resource: "agenda.freebusy", Read: true},
|
||||
{Resource: "agenda.response", Write: true},
|
||||
{Resource: "automation.search", Read: true},
|
||||
{Resource: "automation.chat", Read: true},
|
||||
}, mailScope, driveScope
|
||||
|
||||
@ -88,6 +88,9 @@ func RequirementForRequest(method, fullPath, typesQuery string) (Requirement, bo
|
||||
case strings.HasPrefix(path, "/api/v1/richtext/"):
|
||||
return richtextRequirement(method, path)
|
||||
|
||||
case strings.HasPrefix(path, "/api/v1/search/web"):
|
||||
return Requirement{Resource: "automation.search", Write: false}, true
|
||||
|
||||
case strings.HasPrefix(path, "/api/v1/search"):
|
||||
return searchRequirement(typesQuery)
|
||||
|
||||
|
||||
93
internal/authentik/users.go
Normal file
93
internal/authentik/users.go
Normal file
@ -0,0 +1,93 @@
|
||||
package authentik
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type akUser struct {
|
||||
PK int `json:"pk"`
|
||||
UUID string `json:"uuid"`
|
||||
Attributes map[string]any `json:"attributes"`
|
||||
}
|
||||
|
||||
const avatarAttributeKey = "avatar"
|
||||
|
||||
// FindUserByUUID returns the Authentik core user for an OIDC subject (uuid).
|
||||
func (c *Client) FindUserByUUID(ctx context.Context, userUUID string) (*akUser, bool, error) {
|
||||
userUUID = strings.TrimSpace(userUUID)
|
||||
if userUUID == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
q := url.Values{}
|
||||
q.Set("uuid", userUUID)
|
||||
var out listResponse[akUser]
|
||||
if err := c.getJSON(ctx, "/api/v3/core/users/?"+q.Encode(), &out); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if len(out.Results) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
user := out.Results[0]
|
||||
if user.Attributes == nil {
|
||||
user.Attributes = map[string]any{}
|
||||
}
|
||||
return &user, true, nil
|
||||
}
|
||||
|
||||
// GetUserAvatarAttribute reads attributes.avatar from Authentik.
|
||||
func (c *Client) GetUserAvatarAttribute(ctx context.Context, userUUID string) (string, error) {
|
||||
user, found, err := c.FindUserByUUID(ctx, userUUID)
|
||||
if err != nil || !found {
|
||||
return "", err
|
||||
}
|
||||
return avatarAttributeString(user.Attributes[avatarAttributeKey]), nil
|
||||
}
|
||||
|
||||
// SetUserAvatarAttribute writes attributes.avatar on the Authentik user.
|
||||
func (c *Client) SetUserAvatarAttribute(ctx context.Context, userUUID, avatarURL string) error {
|
||||
user, found, err := c.FindUserByUUID(ctx, userUUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("authentik user not found: %s", userUUID)
|
||||
}
|
||||
attrs := cloneAttributes(user.Attributes)
|
||||
avatarURL = strings.TrimSpace(avatarURL)
|
||||
if avatarURL == "" {
|
||||
delete(attrs, avatarAttributeKey)
|
||||
} else {
|
||||
attrs[avatarAttributeKey] = avatarURL
|
||||
}
|
||||
return c.patchJSON(ctx, fmt.Sprintf("/api/v3/core/users/%d/", user.PK), map[string]any{
|
||||
"attributes": attrs,
|
||||
})
|
||||
}
|
||||
|
||||
func cloneAttributes(src map[string]any) map[string]any {
|
||||
if len(src) == 0 {
|
||||
return map[string]any{}
|
||||
}
|
||||
dst := make(map[string]any, len(src))
|
||||
for k, v := range src {
|
||||
dst[k] = v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func avatarAttributeString(raw any) string {
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(v))
|
||||
default:
|
||||
if raw == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprint(raw))
|
||||
}
|
||||
}
|
||||
74
internal/authentik/users_test.go
Normal file
74
internal/authentik/users_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package authentik
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSetUserAvatarAttribute(t *testing.T) {
|
||||
const userUUID = "11111111-1111-1111-1111-111111111111"
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && r.URL.Path == "/api/v3/core/users/":
|
||||
_ = json.NewEncoder(w).Encode(listResponse[akUser]{
|
||||
Results: []akUser{{
|
||||
PK: 42,
|
||||
UUID: userUUID,
|
||||
Attributes: map[string]any{
|
||||
"phone": "+33123456789",
|
||||
},
|
||||
}},
|
||||
})
|
||||
case r.Method == http.MethodPatch && r.URL.Path == "/api/v3/core/users/42/":
|
||||
var body map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("decode patch body: %v", err)
|
||||
}
|
||||
attrs, ok := body["attributes"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("attributes = %#v", body["attributes"])
|
||||
}
|
||||
if attrs["phone"] != "+33123456789" {
|
||||
t.Fatalf("phone not preserved: %#v", attrs["phone"])
|
||||
}
|
||||
if attrs["avatar"] != "data:image/png;base64,abc" {
|
||||
t.Fatalf("avatar = %#v", attrs["avatar"])
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
default:
|
||||
t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewClient(srv.URL, "token")
|
||||
if err := client.SetUserAvatarAttribute(context.Background(), userUUID, "data:image/png;base64,abc"); err != nil {
|
||||
t.Fatalf("SetUserAvatarAttribute() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserAvatarAttribute(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(listResponse[akUser]{
|
||||
Results: []akUser{{
|
||||
PK: 7,
|
||||
Attributes: map[string]any{
|
||||
"avatar": " data:image/jpeg;base64,xyz ",
|
||||
},
|
||||
}},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewClient(srv.URL, "token")
|
||||
got, err := client.GetUserAvatarAttribute(context.Background(), "uuid")
|
||||
if err != nil {
|
||||
t.Fatalf("GetUserAvatarAttribute() error = %v", err)
|
||||
}
|
||||
if got != "data:image/jpeg;base64,xyz" {
|
||||
t.Fatalf("GetUserAvatarAttribute() = %q", got)
|
||||
}
|
||||
}
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/websearch"
|
||||
)
|
||||
@ -16,6 +17,8 @@ func (s *Service) UpdateSearchSettings(ctx context.Context, externalUserID strin
|
||||
if s.db == nil {
|
||||
return websearch.Settings{}, fmt.Errorf("database unavailable")
|
||||
}
|
||||
existing, _ := s.loadSearchSettings(ctx, externalUserID)
|
||||
settings = mergeSearchProviderSecrets(existing, settings)
|
||||
raw, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
return websearch.Settings{}, err
|
||||
@ -60,3 +63,25 @@ func searchSettingsConfigured(settings websearch.Settings) bool {
|
||||
_, err := websearch.ResolveProvider(settings)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func mergeSearchProviderSecrets(existing, patch websearch.Settings) websearch.Settings {
|
||||
if len(patch.Providers) == 0 {
|
||||
return patch
|
||||
}
|
||||
existingByID := make(map[string]websearch.Provider, len(existing.Providers))
|
||||
for _, provider := range existing.Providers {
|
||||
if provider.ID != "" {
|
||||
existingByID[provider.ID] = provider
|
||||
}
|
||||
}
|
||||
for i, provider := range patch.Providers {
|
||||
if strings.TrimSpace(provider.APIKey) != "" || provider.ID == "" {
|
||||
continue
|
||||
}
|
||||
if old, ok := existingByID[provider.ID]; ok && strings.TrimSpace(old.APIKey) != "" {
|
||||
provider.APIKey = old.APIKey
|
||||
patch.Providers[i] = provider
|
||||
}
|
||||
}
|
||||
return patch
|
||||
}
|
||||
|
||||
41
internal/contacts/discovery/search_settings_test.go
Normal file
41
internal/contacts/discovery/search_settings_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
package discovery
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/websearch"
|
||||
)
|
||||
|
||||
func TestMergeSearchProviderSecretsPreservesExistingKey(t *testing.T) {
|
||||
existing := websearch.Settings{
|
||||
Providers: []websearch.Provider{{
|
||||
ID: "brave-1", Type: websearch.ProviderBrave, APIKey: "stored-key",
|
||||
}},
|
||||
}
|
||||
patch := websearch.Settings{
|
||||
Providers: []websearch.Provider{{
|
||||
ID: "brave-1", Type: websearch.ProviderBrave, APIKey: "",
|
||||
}},
|
||||
}
|
||||
merged := mergeSearchProviderSecrets(existing, patch)
|
||||
if merged.Providers[0].APIKey != "stored-key" {
|
||||
t.Fatalf("api_key = %q, want stored-key", merged.Providers[0].APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSearchProviderSecretsKeepsNewKey(t *testing.T) {
|
||||
existing := websearch.Settings{
|
||||
Providers: []websearch.Provider{{
|
||||
ID: "brave-1", Type: websearch.ProviderBrave, APIKey: "old-key",
|
||||
}},
|
||||
}
|
||||
patch := websearch.Settings{
|
||||
Providers: []websearch.Provider{{
|
||||
ID: "brave-1", Type: websearch.ProviderBrave, APIKey: "new-key",
|
||||
}},
|
||||
}
|
||||
merged := mergeSearchProviderSecrets(existing, patch)
|
||||
if merged.Providers[0].APIKey != "new-key" {
|
||||
t.Fatalf("api_key = %q, want new-key", merged.Providers[0].APIKey)
|
||||
}
|
||||
}
|
||||
34
internal/contacts/discovery/web_search.go
Normal file
34
internal/contacts/discovery/web_search.go
Normal file
@ -0,0 +1,34 @@
|
||||
package discovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/websearch"
|
||||
)
|
||||
|
||||
var ErrWebSearchNotConfigured = errors.New("web search not configured")
|
||||
|
||||
func (s *Service) SearchWeb(ctx context.Context, externalUserID, query string, count int) ([]websearch.Result, error) {
|
||||
if s.websearch == nil {
|
||||
return nil, fmt.Errorf("web search unavailable")
|
||||
}
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return nil, fmt.Errorf("search query is required")
|
||||
}
|
||||
settings, err := s.loadSearchSettings(ctx, externalUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !searchSettingsConfigured(settings) {
|
||||
return nil, ErrWebSearchNotConfigured
|
||||
}
|
||||
provider, err := websearch.ResolveProvider(settings)
|
||||
if err != nil {
|
||||
return nil, ErrWebSearchNotConfigured
|
||||
}
|
||||
return s.websearch.Search(ctx, provider, query, count)
|
||||
}
|
||||
@ -197,6 +197,24 @@ func (s *Store) DeleteOrgFolder(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) ListOrgMounts(ctx context.Context) ([]Mount, error) {
|
||||
if s.db == nil {
|
||||
return nil, fmt.Errorf("database not configured")
|
||||
}
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT id, scope, owner_user_id, org_slug, nc_mount_id, display_name,
|
||||
backend_type, mount_point, status, last_error, created_at, updated_at
|
||||
FROM drive_mounts
|
||||
WHERE scope = 'org'
|
||||
ORDER BY display_name ASC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanMounts(rows)
|
||||
}
|
||||
|
||||
func (s *Store) ListMountsForUser(ctx context.Context, ownerUserID string, orgSlugs []string) ([]Mount, error) {
|
||||
if s.db == nil {
|
||||
return nil, fmt.Errorf("database not configured")
|
||||
@ -206,7 +224,10 @@ func (s *Store) ListMountsForUser(ctx context.Context, ownerUserID string, orgSl
|
||||
backend_type, mount_point, status, last_error, created_at, updated_at
|
||||
FROM drive_mounts
|
||||
WHERE (scope = 'user' AND owner_user_id = $1::uuid)
|
||||
OR (scope = 'org' AND org_slug = ANY($2))
|
||||
OR (
|
||||
scope = 'org'
|
||||
AND (cardinality($2::text[]) = 0 OR org_slug = ANY($2))
|
||||
)
|
||||
ORDER BY display_name ASC
|
||||
`, ownerUserID, orgSlugs)
|
||||
if err != nil {
|
||||
|
||||
@ -14,6 +14,7 @@ import (
|
||||
type Provider struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type,omitempty"`
|
||||
BaseURL string `json:"base_url"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
DefaultModel string `json:"default_model"`
|
||||
|
||||
@ -110,6 +110,12 @@ func (rw *metricsResponseWriter) WriteHeader(code int) {
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (rw *metricsResponseWriter) Flush() {
|
||||
if f, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func HTTPMetrics(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
@ -67,7 +67,9 @@ func (h *Handler) Search(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
filters := SearchFilters{
|
||||
AccountID: strings.TrimSpace(r.URL.Query().Get("account_id")),
|
||||
AccountID: strings.TrimSpace(r.URL.Query().Get("account_id")),
|
||||
IncludeSpam: parseSearchOptionalBool(r.URL.Query().Get("include_spam")),
|
||||
IncludeTrash: parseSearchOptionalBool(r.URL.Query().Get("include_trash")),
|
||||
}
|
||||
|
||||
result, err := h.svc.Search(r.Context(), claims.Sub, q, types, params, filters)
|
||||
|
||||
@ -13,6 +13,8 @@ import (
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
|
||||
mailapi "github.com/ultisuite/ulti-backend/internal/api/mail"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
@ -62,7 +64,9 @@ type SearchResponse struct {
|
||||
}
|
||||
|
||||
type SearchFilters struct {
|
||||
AccountID string
|
||||
AccountID string
|
||||
IncludeSpam bool
|
||||
IncludeTrash bool
|
||||
}
|
||||
|
||||
func (s *Service) Search(ctx context.Context, externalID, q, typesRaw string, params query.ListParams, filters SearchFilters) (SearchResponse, error) {
|
||||
@ -282,6 +286,8 @@ func (s *Service) searchMail(ctx context.Context, externalID, queryText string,
|
||||
args = append(args, params.To.UTC())
|
||||
argIdx++
|
||||
}
|
||||
excludeSpam, excludeTrash := mailapi.HiddenMailboxExclusion("", "", filters.IncludeSpam, filters.IncludeTrash)
|
||||
base, args, argIdx = mailapi.AppendHiddenMailboxExclusion(base, args, argIdx, excludeSpam, excludeTrash)
|
||||
|
||||
querySQL := `
|
||||
SELECT m.id, m.account_id, m.subject, m.snippet, m.date,
|
||||
|
||||
@ -55,3 +55,12 @@ func validateSearchTypes(typesRaw string) *apivalidate.ValidationError {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseSearchOptionalBool(raw string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
58
internal/search/web_handler.go
Normal file
58
internal/search/web_handler.go
Normal file
@ -0,0 +1,58 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
||||
"github.com/ultisuite/ulti-backend/internal/contacts/discovery"
|
||||
)
|
||||
|
||||
type WebHandler struct {
|
||||
discovery *discovery.Service
|
||||
}
|
||||
|
||||
func NewWebHandler(discovery *discovery.Service) *WebHandler {
|
||||
return &WebHandler{discovery: discovery}
|
||||
}
|
||||
|
||||
func (h *WebHandler) Search(w http.ResponseWriter, r *http.Request) {
|
||||
if h.discovery == nil {
|
||||
apiresponse.WriteError(w, r, http.StatusServiceUnavailable, apiresponse.CodeInternal, "web search unavailable", nil)
|
||||
return
|
||||
}
|
||||
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
query := r.URL.Query().Get("q")
|
||||
if query == "" {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidQueryParam, "q is required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
count := 5
|
||||
if raw := r.URL.Query().Get("count"); raw != "" {
|
||||
parsed, err := strconv.Atoi(raw)
|
||||
if err != nil || parsed <= 0 {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidQueryParam, "invalid count", nil)
|
||||
return
|
||||
}
|
||||
count = parsed
|
||||
}
|
||||
|
||||
results, err := h.discovery.SearchWeb(r.Context(), claims.Sub, query, count)
|
||||
if err != nil {
|
||||
if errors.Is(err, discovery.ErrWebSearchNotConfigured) {
|
||||
apiresponse.WriteError(w, r, http.StatusServiceUnavailable, "web_search_not_configured", "web search provider not configured", nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteError(w, r, http.StatusBadGateway, apiresponse.CodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"query": query,
|
||||
"results": results,
|
||||
})
|
||||
}
|
||||
@ -36,6 +36,7 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/automation"
|
||||
"github.com/ultisuite/ulti-backend/internal/authentik"
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
"github.com/ultisuite/ulti-backend/internal/contacts/discovery"
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||
"github.com/ultisuite/ulti-backend/internal/httpcors"
|
||||
@ -412,7 +413,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
r.Use(middleware.Auth(verifierHolder, pool, auditLogger, orgPolicyLoader))
|
||||
r.Use(middleware.EnforceApiTokenPolicy())
|
||||
|
||||
r.Mount("/api/v1/users", usersapi.NewHandler(pool).Routes())
|
||||
r.Mount("/api/v1/users", usersapi.NewHandler(pool, cfg).Routes())
|
||||
adminHandler := admin.NewHandler(pool, auditLogger, cfg, ncClient)
|
||||
adminHandler.SetHostedService(hostedSvc)
|
||||
adminHandler.SetMigrationService(migrationSvc)
|
||||
@ -425,7 +426,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
r.Use(middleware.RequireFullAccount)
|
||||
r.Mount("/api/v1/mail", mailHandler.Routes())
|
||||
r.Mount("/api/v1/migration", migrationHandler.Routes())
|
||||
r.Get("/api/v1/search", search.NewHandler(pool, search.Options{
|
||||
searchHandler := search.NewHandler(pool, search.Options{
|
||||
Nextcloud: ncClient,
|
||||
Engine: cfg.SearchEngine,
|
||||
MeilisearchURL: cfg.MeilisearchURL,
|
||||
@ -434,7 +435,13 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
TypesenseURL: cfg.TypesenseURL,
|
||||
TypesenseKey: cfg.TypesenseKey,
|
||||
TypesenseCollection: cfg.TypesenseCollection,
|
||||
}).Search)
|
||||
})
|
||||
r.Get("/api/v1/search", searchHandler.Search)
|
||||
webDiscovery := discovery.NewService(pool)
|
||||
if contactsHandler != nil && contactsHandler.Discovery() != nil {
|
||||
webDiscovery = contactsHandler.Discovery()
|
||||
}
|
||||
r.Get("/api/v1/search/web", search.NewWebHandler(webDiscovery).Search)
|
||||
if driveHandler != nil {
|
||||
r.Mount("/api/v1/calendar", calendar.NewHandler(ncClient, meetCfg, orgPolicyLoader).Routes())
|
||||
r.Mount("/api/v1/contacts", contactsHandler.Routes())
|
||||
|
||||
67
internal/users/avatar_sync.go
Normal file
67
internal/users/avatar_sync.go
Normal file
@ -0,0 +1,67 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/ai"
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
"github.com/ultisuite/ulti-backend/internal/authentik"
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
)
|
||||
|
||||
// ImportAvatarFromAuthentik copies attributes.avatar into users.avatar_url when PG is empty.
|
||||
func ImportAvatarFromAuthentik(ctx context.Context, db *pgxpool.Pool, cfg *config.Config, externalID string) (string, error) {
|
||||
if db == nil || cfg == nil || strings.TrimSpace(externalID) == "" {
|
||||
return "", nil
|
||||
}
|
||||
current, err := GetAvatarURL(ctx, db, externalID)
|
||||
if err != nil || current != "" {
|
||||
return current, err
|
||||
}
|
||||
client := authentikClient(cfg)
|
||||
if client == nil {
|
||||
return "", nil
|
||||
}
|
||||
raw, err := client.GetUserAvatarAttribute(ctx, externalID)
|
||||
if err != nil || strings.TrimSpace(raw) == "" {
|
||||
return "", err
|
||||
}
|
||||
normalized, err := normalizeAvatarURL(raw)
|
||||
if err != nil {
|
||||
slog.Warn("skip authentik avatar import", "sub", externalID, "error", err)
|
||||
return "", nil
|
||||
}
|
||||
if err := SetAvatarURL(ctx, db, externalID, normalized); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return normalized, nil
|
||||
}
|
||||
|
||||
// PropagateAvatarOutbound syncs avatar to Authentik and UltiAI (OpenWebUI).
|
||||
func PropagateAvatarOutbound(ctx context.Context, db *pgxpool.Pool, cfg *config.Config, claims *auth.Claims, avatarURL string) {
|
||||
if claims == nil || strings.TrimSpace(claims.Sub) == "" {
|
||||
return
|
||||
}
|
||||
avatarURL = strings.TrimSpace(avatarURL)
|
||||
|
||||
if client := authentikClient(cfg); client != nil {
|
||||
if err := client.SetUserAvatarAttribute(ctx, claims.Sub, avatarURL); err != nil {
|
||||
slog.Warn("sync avatar to authentik", "sub", claims.Sub, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := ai.SyncOpenWebUIProfile(ctx, cfg, claims, avatarURL); err != nil {
|
||||
slog.Warn("sync avatar to openwebui", "sub", claims.Sub, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func authentikClient(cfg *config.Config) *authentik.Client {
|
||||
if cfg == nil || strings.TrimSpace(cfg.AuthentikAPIToken) == "" {
|
||||
return nil
|
||||
}
|
||||
return authentik.NewClient(cfg.AuthentikAPIURL, cfg.AuthentikAPIToken)
|
||||
}
|
||||
@ -77,6 +77,21 @@ func LookupUserID(ctx context.Context, db *pgxpool.Pool, externalID string) (str
|
||||
return userID, err
|
||||
}
|
||||
|
||||
// LookupIdentityByEmail returns the OIDC subject and stored email for a user row.
|
||||
func LookupIdentityByEmail(ctx context.Context, db *pgxpool.Pool, email string) (externalID, storedEmail string, err error) {
|
||||
if db == nil {
|
||||
return "", "", fmt.Errorf("database not configured")
|
||||
}
|
||||
email = strings.ToLower(strings.TrimSpace(email))
|
||||
if email == "" {
|
||||
return "", "", pgx.ErrNoRows
|
||||
}
|
||||
err = db.QueryRow(ctx, `
|
||||
SELECT external_id, email FROM users WHERE lower(email) = $1
|
||||
`, email).Scan(&externalID, &storedEmail)
|
||||
return externalID, storedEmail, err
|
||||
}
|
||||
|
||||
// LookupUserIDByEmail returns the internal user UUID for a stored email address.
|
||||
func LookupUserIDByEmail(ctx context.Context, db *pgxpool.Pool, email string) (string, error) {
|
||||
if db == nil {
|
||||
|
||||
80
internal/websearch/bing.go
Normal file
80
internal/websearch/bing.go
Normal file
@ -0,0 +1,80 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type bingSearchResponse struct {
|
||||
WebPages *struct {
|
||||
Value []struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Snippet string `json:"snippet"`
|
||||
} `json:"value"`
|
||||
} `json:"webPages"`
|
||||
}
|
||||
|
||||
func (c *Client) searchBing(ctx context.Context, provider Provider, query string, count int) ([]Result, error) {
|
||||
apiKey := strings.TrimSpace(provider.APIKey)
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("bing api key is required")
|
||||
}
|
||||
|
||||
endpoint := strings.TrimSpace(provider.BaseURL)
|
||||
if endpoint == "" {
|
||||
endpoint = bingSearchURL
|
||||
}
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set(queryParamName(provider), query)
|
||||
q.Set("count", fmt.Sprintf("%d", count))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
header := strings.TrimSpace(provider.AuthHeader)
|
||||
if header == "" {
|
||||
header = "Ocp-Apim-Subscription-Key"
|
||||
}
|
||||
req.Header.Set(header, apiKey)
|
||||
|
||||
body, status, err := c.doRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status >= 400 {
|
||||
return nil, fmt.Errorf("bing search failed (%d): %s", status, string(body))
|
||||
}
|
||||
|
||||
var parsed bingSearchResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed.WebPages == nil || len(parsed.WebPages.Value) == 0 {
|
||||
return []Result{}, nil
|
||||
}
|
||||
|
||||
results := make([]Result, 0, len(parsed.WebPages.Value))
|
||||
for _, item := range parsed.WebPages.Value {
|
||||
results = append(results, Result{
|
||||
Title: strings.TrimSpace(item.Name),
|
||||
URL: strings.TrimSpace(item.URL),
|
||||
Description: strings.TrimSpace(item.Snippet),
|
||||
})
|
||||
}
|
||||
if len(results) > count {
|
||||
results = results[:count]
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
84
internal/websearch/brave.go
Normal file
84
internal/websearch/brave.go
Normal file
@ -0,0 +1,84 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type braveSearchResponse struct {
|
||||
Web *struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
} `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
|
||||
func (c *Client) searchBrave(ctx context.Context, apiKey, query string, count int) ([]Result, error) {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("brave api key is required")
|
||||
}
|
||||
|
||||
u, err := url.Parse(braveSearchURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("q", query)
|
||||
q.Set("count", fmt.Sprintf("%d", count))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
req.Header.Set("X-Subscription-Token", apiKey)
|
||||
|
||||
body, status, err := c.doRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status >= 400 {
|
||||
return nil, fmt.Errorf("brave search failed (%d): %s", status, string(body))
|
||||
}
|
||||
|
||||
var parsed braveSearchResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed.Web == nil || len(parsed.Web.Results) == 0 {
|
||||
return []Result{}, nil
|
||||
}
|
||||
|
||||
results := make([]Result, 0, len(parsed.Web.Results))
|
||||
for _, item := range parsed.Web.Results {
|
||||
results = append(results, Result{
|
||||
Title: strings.TrimSpace(item.Title),
|
||||
URL: strings.TrimSpace(item.URL),
|
||||
Description: strings.TrimSpace(item.Description),
|
||||
})
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (c *Client) doRequest(req *http.Request) ([]byte, int, error) {
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, err
|
||||
}
|
||||
return body, resp.StatusCode, nil
|
||||
}
|
||||
@ -2,16 +2,15 @@ package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var braveSearchURL = "https://api.search.brave.com/res/v1/web/search"
|
||||
var bingSearchURL = "https://api.bing.microsoft.com/v7.0/search"
|
||||
var duckDuckGoHTMLURL = "https://html.duckduckgo.com/html/"
|
||||
|
||||
type Result struct {
|
||||
Title string `json:"title"`
|
||||
@ -27,16 +26,6 @@ func NewClient() *Client {
|
||||
return &Client{http: &http.Client{Timeout: 15 * time.Second}}
|
||||
}
|
||||
|
||||
type braveSearchResponse struct {
|
||||
Web *struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
} `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
|
||||
func (c *Client) Search(ctx context.Context, provider Provider, query string, count int) ([]Result, error) {
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
@ -52,63 +41,15 @@ func (c *Client) Search(ctx context.Context, provider Provider, query string, co
|
||||
switch provider.Type {
|
||||
case ProviderBrave:
|
||||
return c.searchBrave(ctx, provider.APIKey, query, count)
|
||||
case ProviderBing:
|
||||
return c.searchBing(ctx, provider, query, count)
|
||||
case ProviderDuckDuckGo:
|
||||
return c.searchDuckDuckGo(ctx, provider, query, count)
|
||||
case ProviderSearXNG:
|
||||
return c.searchSearXNG(ctx, provider, query, count)
|
||||
case ProviderCustom:
|
||||
return c.searchCustom(ctx, provider, query, count)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported search provider type: %s", provider.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) searchBrave(ctx context.Context, apiKey, query string, count int) ([]Result, error) {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("brave api key is required")
|
||||
}
|
||||
|
||||
u, err := url.Parse(braveSearchURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("q", query)
|
||||
q.Set("count", fmt.Sprintf("%d", count))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
req.Header.Set("X-Subscription-Token", apiKey)
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("brave search failed (%d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var parsed braveSearchResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed.Web == nil || len(parsed.Web.Results) == 0 {
|
||||
return []Result{}, nil
|
||||
}
|
||||
|
||||
results := make([]Result, 0, len(parsed.Web.Results))
|
||||
for _, item := range parsed.Web.Results {
|
||||
results = append(results, Result{
|
||||
Title: strings.TrimSpace(item.Title),
|
||||
URL: strings.TrimSpace(item.URL),
|
||||
Description: strings.TrimSpace(item.Description),
|
||||
})
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
@ -44,6 +44,110 @@ func TestSearchBrave(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchBing(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Ocp-Apim-Subscription-Key"); got != "bing-key" {
|
||||
t.Fatalf("unexpected bing key: %q", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"webPages":{"value":[{"name":"Result","url":"https://example.com","snippet":"Hello"}]}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
oldURL := bingSearchURL
|
||||
bingSearchURL = srv.URL
|
||||
t.Cleanup(func() { bingSearchURL = oldURL })
|
||||
|
||||
client := NewClient()
|
||||
results, err := client.Search(context.Background(), Provider{
|
||||
Type: ProviderBing,
|
||||
APIKey: "bing-key",
|
||||
}, "hello", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Search: %v", err)
|
||||
}
|
||||
if len(results) != 1 || results[0].URL != "https://example.com" {
|
||||
t.Fatalf("unexpected results: %#v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchSearXNG(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.URL.Query().Get("format"); got != "json" {
|
||||
t.Fatalf("expected json format, got %q", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"results":[{"title":"SearX","url":"https://searx.example","content":"desc"}]}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewClient()
|
||||
results, err := client.Search(context.Background(), Provider{
|
||||
Type: ProviderSearXNG,
|
||||
BaseURL: srv.URL,
|
||||
}, "test", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Search: %v", err)
|
||||
}
|
||||
if len(results) != 1 || results[0].Title != "SearX" {
|
||||
t.Fatalf("unexpected results: %#v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchCustomJSON(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.URL.Query().Get("q"); got != "custom query" {
|
||||
t.Fatalf("unexpected query: %q", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"data":{"items":[{"title":"Custom","link":"https://custom.test","summary":"note"}]}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewClient()
|
||||
results, err := client.Search(context.Background(), Provider{
|
||||
Type: ProviderCustom,
|
||||
BaseURL: srv.URL,
|
||||
ResultsPath: "data.items",
|
||||
TitleField: "title",
|
||||
URLField: "link",
|
||||
DescField: "summary",
|
||||
}, "custom query", 3)
|
||||
if err != nil {
|
||||
t.Fatalf("Search: %v", err)
|
||||
}
|
||||
if len(results) != 1 || results[0].Title != "Custom" {
|
||||
t.Fatalf("unexpected results: %#v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchDuckDuckGoHTML(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
_, _ = w.Write([]byte(`
|
||||
<a class="result__a" href="https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com">Example</a>
|
||||
<a class="result__snippet">Example description</a>
|
||||
`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
oldURL := duckDuckGoHTMLURL
|
||||
duckDuckGoHTMLURL = srv.URL
|
||||
t.Cleanup(func() { duckDuckGoHTMLURL = oldURL })
|
||||
|
||||
client := NewClient()
|
||||
results, err := client.Search(context.Background(), Provider{Type: ProviderDuckDuckGo}, "example", 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Search: %v", err)
|
||||
}
|
||||
if len(results) != 1 || results[0].URL != "https://example.com" {
|
||||
t.Fatalf("unexpected results: %#v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildContactSearchQuery(t *testing.T) {
|
||||
got := BuildContactSearchQuery("Jean", "Dupont", "Marie", "Ultimail", "CTO", "Lyon")
|
||||
want := "Jean Marie Dupont Ultimail CTO Lyon"
|
||||
@ -84,3 +188,23 @@ func TestResolveProvider(t *testing.T) {
|
||||
t.Fatalf("unexpected provider: %#v", p)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigured(t *testing.T) {
|
||||
cases := []struct {
|
||||
p Provider
|
||||
want bool
|
||||
}{
|
||||
{Provider{Type: ProviderDuckDuckGo}, true},
|
||||
{Provider{Type: ProviderSearXNG, BaseURL: "https://searx.local"}, true},
|
||||
{Provider{Type: ProviderSearXNG}, false},
|
||||
{Provider{
|
||||
Type: ProviderCustom, BaseURL: "https://api.local",
|
||||
ResultsPath: "results", TitleField: "title", URLField: "url",
|
||||
}, true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := providerConfigured(tc.p); got != tc.want {
|
||||
t.Fatalf("providerConfigured(%#v) = %v, want %v", tc.p, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
100
internal/websearch/custom.go
Normal file
100
internal/websearch/custom.go
Normal file
@ -0,0 +1,100 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *Client) searchCustom(ctx context.Context, provider Provider, query string, count int) ([]Result, error) {
|
||||
endpoint := strings.TrimSpace(provider.BaseURL)
|
||||
if endpoint == "" {
|
||||
return nil, fmt.Errorf("custom search base url is required")
|
||||
}
|
||||
|
||||
endpoint = strings.NewReplacer(
|
||||
"{query}", url.QueryEscape(query),
|
||||
"{count}", fmt.Sprintf("%d", count),
|
||||
).Replace(endpoint)
|
||||
|
||||
reqURL, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !strings.Contains(provider.BaseURL, "{query}") {
|
||||
q := reqURL.Query()
|
||||
q.Set(queryParamName(provider), query)
|
||||
if !strings.Contains(provider.BaseURL, "{count}") {
|
||||
q.Set("count", fmt.Sprintf("%d", count))
|
||||
}
|
||||
reqURL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if apiKey := strings.TrimSpace(provider.APIKey); apiKey != "" {
|
||||
header := strings.TrimSpace(provider.AuthHeader)
|
||||
switch {
|
||||
case header != "":
|
||||
if strings.EqualFold(header, "Authorization") {
|
||||
req.Header.Set(header, "Bearer "+apiKey)
|
||||
} else {
|
||||
req.Header.Set(header, apiKey)
|
||||
}
|
||||
case strings.Contains(reqURL.RawQuery, "api_key="):
|
||||
// key already in URL
|
||||
default:
|
||||
q := reqURL.Query()
|
||||
q.Set("api_key", apiKey)
|
||||
reqURL.RawQuery = q.Encode()
|
||||
req.URL = reqURL
|
||||
}
|
||||
}
|
||||
|
||||
body, status, err := c.doRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status >= 400 {
|
||||
return nil, fmt.Errorf("custom search failed (%d): %s", status, string(body))
|
||||
}
|
||||
|
||||
root, err := decodeJSONBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items, ok := resultsArray(root, provider.ResultsPath)
|
||||
if !ok {
|
||||
return []Result{}, nil
|
||||
}
|
||||
|
||||
descField := strings.TrimSpace(provider.DescField)
|
||||
if descField == "" {
|
||||
descField = "description"
|
||||
}
|
||||
|
||||
limit := count
|
||||
if len(items) < limit {
|
||||
limit = len(items)
|
||||
}
|
||||
results := make([]Result, 0, limit)
|
||||
for _, item := range items[:limit] {
|
||||
title := jsonFieldString(item, provider.TitleField)
|
||||
link := jsonFieldString(item, provider.URLField)
|
||||
desc := jsonFieldString(item, descField)
|
||||
if title == "" && link == "" {
|
||||
continue
|
||||
}
|
||||
results = append(results, Result{
|
||||
Title: title,
|
||||
URL: link,
|
||||
Description: desc,
|
||||
})
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
104
internal/websearch/duckduckgo.go
Normal file
104
internal/websearch/duckduckgo.go
Normal file
@ -0,0 +1,104 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
duckResultLinkRe = regexp.MustCompile(`(?is)<a[^>]+class="result__a"[^>]+href="([^"]+)"[^>]*>(.*?)</a>`)
|
||||
duckResultSnipRe = regexp.MustCompile(`(?is)<a[^>]+class="result__snippet"[^>]*>(.*?)</a>`)
|
||||
htmlTagRe = regexp.MustCompile(`(?is)<[^>]+>`)
|
||||
)
|
||||
|
||||
func (c *Client) searchDuckDuckGo(ctx context.Context, provider Provider, query string, count int) ([]Result, error) {
|
||||
endpoint := strings.TrimSpace(provider.BaseURL)
|
||||
if endpoint == "" {
|
||||
endpoint = duckDuckGoHTMLURL
|
||||
}
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("q", query)
|
||||
form.Set("b", "")
|
||||
form.Set("kl", "")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "text/html")
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("duckduckgo search failed (%d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
html := string(body)
|
||||
links := duckResultLinkRe.FindAllStringSubmatch(html, count)
|
||||
snips := duckResultSnipRe.FindAllStringSubmatch(html, count)
|
||||
if len(links) == 0 {
|
||||
return []Result{}, nil
|
||||
}
|
||||
|
||||
results := make([]Result, 0, len(links))
|
||||
for i, match := range links {
|
||||
if len(match) < 3 {
|
||||
continue
|
||||
}
|
||||
title := stripHTML(match[2])
|
||||
href := decodeDuckDuckGoURL(strings.TrimSpace(match[1]))
|
||||
desc := ""
|
||||
if i < len(snips) && len(snips[i]) > 1 {
|
||||
desc = stripHTML(snips[i][1])
|
||||
}
|
||||
if title == "" || href == "" {
|
||||
continue
|
||||
}
|
||||
results = append(results, Result{
|
||||
Title: title,
|
||||
URL: href,
|
||||
Description: desc,
|
||||
})
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func stripHTML(raw string) string {
|
||||
return strings.TrimSpace(htmlTagRe.ReplaceAllString(raw, ""))
|
||||
}
|
||||
|
||||
func decodeDuckDuckGoURL(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(raw, "//") {
|
||||
raw = "https:" + raw
|
||||
}
|
||||
parsed, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return raw
|
||||
}
|
||||
if strings.Contains(parsed.Host, "duckduckgo.com") && parsed.Path == "/l/" {
|
||||
if uddg := parsed.Query().Get("uddg"); uddg != "" {
|
||||
if decoded, err := url.QueryUnescape(uddg); err == nil && decoded != "" {
|
||||
return decoded
|
||||
}
|
||||
}
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
62
internal/websearch/jsonpath.go
Normal file
62
internal/websearch/jsonpath.go
Normal file
@ -0,0 +1,62 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func jsonPathValue(v any, path string) (any, bool) {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return v, v != nil
|
||||
}
|
||||
cur := v
|
||||
for _, part := range strings.Split(path, ".") {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
switch node := cur.(type) {
|
||||
case map[string]any:
|
||||
next, ok := node[part]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
cur = next
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return cur, true
|
||||
}
|
||||
|
||||
func jsonFieldString(item any, field string) string {
|
||||
val, ok := jsonPathValue(item, field)
|
||||
if !ok || val == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
|
||||
func decodeJSONBody(body []byte) (any, error) {
|
||||
var root any
|
||||
if err := json.Unmarshal(body, &root); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func resultsArray(root any, path string) ([]any, bool) {
|
||||
val, ok := jsonPathValue(root, path)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
items, ok := val.([]any)
|
||||
return items, ok
|
||||
}
|
||||
82
internal/websearch/searxng.go
Normal file
82
internal/websearch/searxng.go
Normal file
@ -0,0 +1,82 @@
|
||||
package websearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type searxNGResponse struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
} `json:"results"`
|
||||
}
|
||||
|
||||
func (c *Client) searchSearXNG(ctx context.Context, provider Provider, query string, count int) ([]Result, error) {
|
||||
base := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
|
||||
if base == "" {
|
||||
return nil, fmt.Errorf("searxng base url is required")
|
||||
}
|
||||
|
||||
u, err := url.Parse(base + "/search")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set(queryParamName(provider), query)
|
||||
q.Set("format", "json")
|
||||
q.Set("categories", "general")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if apiKey := strings.TrimSpace(provider.APIKey); apiKey != "" {
|
||||
header := strings.TrimSpace(provider.AuthHeader)
|
||||
if header == "" {
|
||||
header = "Authorization"
|
||||
}
|
||||
if strings.EqualFold(header, "Authorization") {
|
||||
req.Header.Set(header, "Bearer "+apiKey)
|
||||
} else {
|
||||
req.Header.Set(header, apiKey)
|
||||
}
|
||||
}
|
||||
|
||||
body, status, err := c.doRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status >= 400 {
|
||||
return nil, fmt.Errorf("searxng search failed (%d): %s", status, string(body))
|
||||
}
|
||||
|
||||
var parsed searxNGResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(parsed.Results) == 0 {
|
||||
return []Result{}, nil
|
||||
}
|
||||
|
||||
limit := count
|
||||
if len(parsed.Results) < limit {
|
||||
limit = len(parsed.Results)
|
||||
}
|
||||
results := make([]Result, 0, limit)
|
||||
for _, item := range parsed.Results[:limit] {
|
||||
results = append(results, Result{
|
||||
Title: strings.TrimSpace(item.Title),
|
||||
URL: strings.TrimSpace(item.URL),
|
||||
Description: strings.TrimSpace(item.Content),
|
||||
})
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
@ -7,13 +7,26 @@ import (
|
||||
|
||||
type ProviderType string
|
||||
|
||||
const ProviderBrave ProviderType = "brave"
|
||||
const (
|
||||
ProviderBrave ProviderType = "brave"
|
||||
ProviderBing ProviderType = "bing"
|
||||
ProviderDuckDuckGo ProviderType = "duckduckgo"
|
||||
ProviderSearXNG ProviderType = "searxng"
|
||||
ProviderCustom ProviderType = "custom"
|
||||
)
|
||||
|
||||
type Provider struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type ProviderType `json:"type"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type ProviderType `json:"type"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
QueryParam string `json:"query_param,omitempty"`
|
||||
AuthHeader string `json:"auth_header,omitempty"`
|
||||
ResultsPath string `json:"results_path,omitempty"`
|
||||
TitleField string `json:"title_field,omitempty"`
|
||||
URLField string `json:"url_field,omitempty"`
|
||||
DescField string `json:"description_field,omitempty"`
|
||||
}
|
||||
|
||||
type Settings struct {
|
||||
@ -39,5 +52,26 @@ func ResolveProvider(settings Settings) (Provider, error) {
|
||||
}
|
||||
|
||||
func providerConfigured(p Provider) bool {
|
||||
return strings.TrimSpace(p.APIKey) != "" && strings.TrimSpace(string(p.Type)) != ""
|
||||
switch p.Type {
|
||||
case ProviderBrave, ProviderBing:
|
||||
return strings.TrimSpace(p.APIKey) != ""
|
||||
case ProviderDuckDuckGo:
|
||||
return true
|
||||
case ProviderSearXNG:
|
||||
return strings.TrimSpace(p.BaseURL) != ""
|
||||
case ProviderCustom:
|
||||
return strings.TrimSpace(p.BaseURL) != "" &&
|
||||
strings.TrimSpace(p.ResultsPath) != "" &&
|
||||
strings.TrimSpace(p.TitleField) != "" &&
|
||||
strings.TrimSpace(p.URLField) != ""
|
||||
default:
|
||||
return strings.TrimSpace(string(p.Type)) != "" && strings.TrimSpace(p.APIKey) != ""
|
||||
}
|
||||
}
|
||||
|
||||
func queryParamName(provider Provider) string {
|
||||
if name := strings.TrimSpace(provider.QueryParam); name != "" {
|
||||
return name
|
||||
}
|
||||
return "q"
|
||||
}
|
||||
|
||||
2
migrations/000051_user_groups.down.sql
Normal file
2
migrations/000051_user_groups.down.sql
Normal file
@ -0,0 +1,2 @@
|
||||
DROP TABLE IF EXISTS user_group_members;
|
||||
DROP TABLE IF EXISTS user_groups;
|
||||
18
migrations/000051_user_groups.up.sql
Normal file
18
migrations/000051_user_groups.up.sql
Normal file
@ -0,0 +1,18 @@
|
||||
CREATE TABLE user_groups (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name TEXT NOT NULL,
|
||||
description TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX idx_user_groups_name_lower ON user_groups (LOWER(name));
|
||||
|
||||
CREATE TABLE user_group_members (
|
||||
group_id UUID NOT NULL REFERENCES user_groups(id) ON DELETE CASCADE,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (group_id, user_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_user_group_members_user ON user_group_members(user_id);
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
Tu es UltiAI, l'assistant de la suite souveraine Ultimail.
|
||||
|
||||
- Utilise les tools pour lire/agir sur mail, drive, contacts et documents UltiDocs quand c'est pertinent.
|
||||
- Utilise les tools pour lire/agir sur mail, drive, contacts, documents UltiDocs et rechercher sur le web (web_search) quand c'est pertinent.
|
||||
- Cite les sources (sujet mail, chemin fichier, nom contact).
|
||||
- Ne fabrique pas de données : interroge l'API via les tools.
|
||||
- Réponds en français par défaut.
|
||||
|
||||
BIN
services/openwebui/static/favicon.ico
Normal file
BIN
services/openwebui/static/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.9 KiB |
323
services/openwebui/static/ulti-session.js
Normal file
323
services/openwebui/static/ulti-session.js
Normal file
@ -0,0 +1,323 @@
|
||||
/**
|
||||
* UltiAI embed bootstrap — must run before OpenWebUI SvelteKit init.
|
||||
* Layout requires localStorage.token; without embed signin it redirects to /auth → redirect loop → 500.
|
||||
*/
|
||||
;(function () {
|
||||
const STORAGE_KEY = "ulti.ai.session"
|
||||
const CONTEXT_KEY = "ulti.ai.context"
|
||||
const DEFAULT_MODEL_KEY = "ulti.ai.default_model"
|
||||
const LAST_CHAT_PATH_KEY = "ulti.ai.last_chat_path"
|
||||
const SIGNIN_URL = "/ai/api/v1/auths/signin"
|
||||
const CONFIG_URL = "/api/v1/ai/config"
|
||||
const BASE_SYSTEM_PROMPT = [
|
||||
"Tu es UltiAI, l'assistant intégré à la suite Ultimail (mail, drive, contacts, agenda).",
|
||||
"Réponds en français sauf demande contraire. Utilise les tools disponibles pour agir sur les données utilisateur.",
|
||||
"Après chaque appel d'outil, réponds toujours en langage naturel : résume le résultat, cite les sources, propose la suite.",
|
||||
"Ne termine jamais un tour utilisateur avec uniquement un appel d'outil sans texte explicatif.",
|
||||
"Respecte strictement le paramètre limit des tools.",
|
||||
].join(" ")
|
||||
|
||||
function readPageUrl() {
|
||||
try {
|
||||
return new URL(window.location.href)
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
function isTemporaryEmbed() {
|
||||
const url = readPageUrl()
|
||||
return url?.searchParams.get("temporary-chat") === "true"
|
||||
}
|
||||
|
||||
function modelFromUrl() {
|
||||
const url = readPageUrl()
|
||||
if (!url) return ""
|
||||
const raw = url.searchParams.get("model") || url.searchParams.get("models") || ""
|
||||
return raw.split(",")[0]?.trim() || ""
|
||||
}
|
||||
|
||||
function applySelectedModels(modelIds) {
|
||||
const ids = modelIds.map((id) => String(id || "").trim()).filter(Boolean)
|
||||
if (!ids.length) return
|
||||
try {
|
||||
sessionStorage.setItem("selectedModels", JSON.stringify(ids))
|
||||
localStorage.setItem(DEFAULT_MODEL_KEY, ids[0])
|
||||
} catch {
|
||||
// storage unavailable
|
||||
}
|
||||
}
|
||||
|
||||
function bootstrapSelectedModels(modelHint) {
|
||||
const resolved = modelFromUrl() || String(modelHint || "").trim() || localStorage.getItem(DEFAULT_MODEL_KEY) || ""
|
||||
if (resolved) applySelectedModels([resolved])
|
||||
return resolved
|
||||
}
|
||||
|
||||
function restoreLastChatIfHome() {
|
||||
if (isTemporaryEmbed()) return
|
||||
const url = readPageUrl()
|
||||
if (!url) return
|
||||
const path = url.pathname.replace(/\/$/, "") || "/"
|
||||
if (path !== "/ai") return
|
||||
if (url.searchParams.get("model") || url.searchParams.get("models")) return
|
||||
try {
|
||||
const last = localStorage.getItem(LAST_CHAT_PATH_KEY)
|
||||
if (!last || !/\/c\/[^/]+/.test(last)) return
|
||||
const target = last.startsWith("/") ? last : `/ai${last}`
|
||||
if (target === url.pathname + url.search) return
|
||||
window.location.replace(target)
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
|
||||
function watchChatRoute() {
|
||||
const save = () => {
|
||||
if (isTemporaryEmbed()) return
|
||||
try {
|
||||
const path = window.location.pathname
|
||||
if (/\/c\/[^/]+/.test(path)) {
|
||||
localStorage.setItem(LAST_CHAT_PATH_KEY, path + window.location.search)
|
||||
}
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
window.addEventListener("popstate", save)
|
||||
const timer = window.setInterval(save, 2000)
|
||||
window.addEventListener("beforeunload", () => window.clearInterval(timer))
|
||||
save()
|
||||
}
|
||||
|
||||
function ensureEmbedSignin() {
|
||||
try {
|
||||
if (localStorage.getItem("token")) return
|
||||
const xhr = new XMLHttpRequest()
|
||||
xhr.open("POST", SIGNIN_URL, false)
|
||||
xhr.withCredentials = true
|
||||
xhr.setRequestHeader("Content-Type", "application/json")
|
||||
xhr.send(JSON.stringify({ email: "", password: "" }))
|
||||
if (xhr.status !== 200) return
|
||||
const data = JSON.parse(xhr.responseText)
|
||||
if (data && data.token) {
|
||||
localStorage.setItem("token", data.token)
|
||||
}
|
||||
} catch {
|
||||
// Ultimail session not ready yet — parent may post ULTI_SESSION later
|
||||
}
|
||||
}
|
||||
|
||||
async function fetchEmbedConfig() {
|
||||
try {
|
||||
const response = await fetch(CONFIG_URL, { credentials: "include" })
|
||||
if (!response.ok) return null
|
||||
return await response.json()
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
async function syncUserModelPreference(modelId) {
|
||||
const model = String(modelId || "").trim()
|
||||
if (!model) return false
|
||||
const token = localStorage.getItem("token")
|
||||
if (!token) return false
|
||||
try {
|
||||
const settingsRes = await fetch("/ai/api/v1/users/user/settings", {
|
||||
headers: { Authorization: `Bearer ${token}` },
|
||||
credentials: "include",
|
||||
})
|
||||
if (!settingsRes.ok) return false
|
||||
const settings = await settingsRes.json()
|
||||
const ui = settings?.ui || {}
|
||||
const current = ui.models || ui.model_ids
|
||||
if (Array.isArray(current) && current.length > 0 && String(current[0] || "").trim()) {
|
||||
return true
|
||||
}
|
||||
|
||||
const modelsRes = await fetch("/ai/api/models", {
|
||||
headers: { Authorization: `Bearer ${token}` },
|
||||
credentials: "include",
|
||||
})
|
||||
if (modelsRes.ok) {
|
||||
const modelsPayload = await modelsRes.json()
|
||||
const items = Array.isArray(modelsPayload?.data)
|
||||
? modelsPayload.data
|
||||
: Array.isArray(modelsPayload)
|
||||
? modelsPayload
|
||||
: []
|
||||
const known = items.some((entry) => entry && entry.id === model)
|
||||
if (!known) return false
|
||||
}
|
||||
|
||||
const updateRes = await fetch("/ai/api/v1/users/user/settings/update", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
credentials: "include",
|
||||
body: JSON.stringify({ ui: { models: [model], model_ids: [model] } }),
|
||||
})
|
||||
if (!updateRes.ok) return false
|
||||
applySelectedModels([model])
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
function scheduleModelSync(modelId) {
|
||||
const model = String(modelId || "").trim()
|
||||
if (!model) return
|
||||
let attempts = 0
|
||||
const tick = async () => {
|
||||
attempts += 1
|
||||
const done = await syncUserModelPreference(model)
|
||||
if (done || attempts >= 12) return
|
||||
window.setTimeout(tick, 1000)
|
||||
}
|
||||
window.setTimeout(tick, 500)
|
||||
}
|
||||
|
||||
function readContextPrompt() {
|
||||
try {
|
||||
const raw = sessionStorage.getItem(CONTEXT_KEY)
|
||||
if (!raw) return BASE_SYSTEM_PROMPT
|
||||
const parsed = JSON.parse(raw)
|
||||
const parts = [BASE_SYSTEM_PROMPT, parsed.systemPrompt].filter(Boolean)
|
||||
return parts.join("\n\n")
|
||||
} catch {
|
||||
return BASE_SYSTEM_PROMPT
|
||||
}
|
||||
}
|
||||
|
||||
function injectSystemPrompt(body) {
|
||||
if (!body || !Array.isArray(body.messages)) return body
|
||||
const prompt = readContextPrompt()
|
||||
if (!prompt) return body
|
||||
const messages = body.messages.slice()
|
||||
const systemIndex = messages.findIndex((m) => m && m.role === "system")
|
||||
if (systemIndex >= 0) {
|
||||
messages[systemIndex] = { ...messages[systemIndex], content: prompt }
|
||||
} else {
|
||||
messages.unshift({ role: "system", content: prompt })
|
||||
}
|
||||
return { ...body, messages }
|
||||
}
|
||||
|
||||
function patchFetch() {
|
||||
if (window.__ultiFetchPatched) return
|
||||
window.__ultiFetchPatched = true
|
||||
const originalFetch = window.fetch.bind(window)
|
||||
window.fetch = async function patchedFetch(input, init) {
|
||||
const url = typeof input === "string" ? input : input instanceof Request ? input.url : ""
|
||||
const method = (init && init.method) || (input instanceof Request ? input.method : "GET")
|
||||
if (
|
||||
method.toUpperCase() === "POST" &&
|
||||
/\/api\/(v1\/)?chat\/completions(?:\?|$)/.test(url) &&
|
||||
init &&
|
||||
typeof init.body === "string"
|
||||
) {
|
||||
try {
|
||||
const body = JSON.parse(init.body)
|
||||
const patched = injectSystemPrompt(body)
|
||||
if (patched !== body) {
|
||||
init = { ...init, body: JSON.stringify(patched) }
|
||||
}
|
||||
} catch {
|
||||
// ignore malformed bodies
|
||||
}
|
||||
}
|
||||
return originalFetch(input, init)
|
||||
}
|
||||
}
|
||||
|
||||
function brandDocument() {
|
||||
const titleEl = document.querySelector("title")
|
||||
if (titleEl) {
|
||||
const t = titleEl.textContent || ""
|
||||
if (t.includes("Open WebUI")) {
|
||||
titleEl.textContent = t.replace(/UltiAI \(Open WebUI\)/g, "UltiAI").replace(/Open WebUI/g, "UltiAI")
|
||||
} else if (!t.trim()) {
|
||||
titleEl.textContent = "UltiAI"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function watchDocumentTitle() {
|
||||
brandDocument()
|
||||
const titleEl = document.querySelector("title")
|
||||
if (!titleEl) return
|
||||
new MutationObserver(() => brandDocument()).observe(titleEl, {
|
||||
childList: true,
|
||||
characterData: true,
|
||||
subtree: true,
|
||||
})
|
||||
}
|
||||
|
||||
restoreLastChatIfHome()
|
||||
bootstrapSelectedModels("")
|
||||
ensureEmbedSignin()
|
||||
patchFetch()
|
||||
watchDocumentTitle()
|
||||
watchChatRoute()
|
||||
|
||||
void (async () => {
|
||||
const config = await fetchEmbedConfig()
|
||||
const model = bootstrapSelectedModels(config?.default_model)
|
||||
if (!localStorage.getItem("token")) ensureEmbedSignin()
|
||||
if (model) scheduleModelSync(model)
|
||||
})()
|
||||
|
||||
window.addEventListener("message", (event) => {
|
||||
if (event.source !== window.parent) return
|
||||
const data = event.data
|
||||
if (!data || typeof data !== "object") return
|
||||
|
||||
if (data.type === "ULTI_SESSION") {
|
||||
try {
|
||||
sessionStorage.setItem(
|
||||
STORAGE_KEY,
|
||||
JSON.stringify({
|
||||
token_secret: data.token_secret,
|
||||
session_id: data.session_id,
|
||||
mcp_url: data.mcp_url,
|
||||
enabled_tools: data.enabled_tools,
|
||||
default_model: data.default_model,
|
||||
updated_at: Date.now(),
|
||||
})
|
||||
)
|
||||
} catch {
|
||||
// sessionStorage unavailable
|
||||
}
|
||||
if (data.default_model && typeof data.default_model === "string") {
|
||||
const model = data.default_model.trim()
|
||||
if (model) {
|
||||
applySelectedModels([model])
|
||||
scheduleModelSync(model)
|
||||
}
|
||||
}
|
||||
if (!localStorage.getItem("token")) {
|
||||
ensureEmbedSignin()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if (data.type === "ULTI_CONTEXT_UPDATE") {
|
||||
try {
|
||||
sessionStorage.setItem(
|
||||
CONTEXT_KEY,
|
||||
JSON.stringify({
|
||||
systemPrompt: data.systemPrompt,
|
||||
context: data.context,
|
||||
updated_at: Date.now(),
|
||||
})
|
||||
)
|
||||
} catch {
|
||||
// sessionStorage unavailable
|
||||
}
|
||||
}
|
||||
})
|
||||
})()
|
||||
726
services/ultimail-mcp/dist/index.js
vendored
Normal file
726
services/ultimail-mcp/dist/index.js
vendored
Normal file
@ -0,0 +1,726 @@
|
||||
import { randomUUID } from "node:crypto";
|
||||
import express from "express";
|
||||
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
|
||||
import { SSEServerTransport } from "@modelcontextprotocol/sdk/server/sse.js";
|
||||
import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/streamableHttp.js";
|
||||
import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js";
|
||||
import { z } from "zod";
|
||||
const PORT = Number(process.env.MCP_PORT ?? 3100);
|
||||
const API_BASE = (process.env.ULTID_API_URL ?? "http://localhost:8080/api/v1").replace(/\/$/, "");
|
||||
const TOOL_GROUPS = {
|
||||
mail: [
|
||||
"mail_search",
|
||||
"mail_read",
|
||||
"mail_list",
|
||||
"mail_send",
|
||||
"mail_update_labels",
|
||||
"mail_update_flags",
|
||||
"mail_delete",
|
||||
],
|
||||
drive: [
|
||||
"drive_list",
|
||||
"drive_file_info",
|
||||
"drive_create_file",
|
||||
"drive_create_folder",
|
||||
"drive_delete",
|
||||
"drive_move",
|
||||
"drive_rename",
|
||||
"drive_share_create",
|
||||
"drive_share_update",
|
||||
"drive_share_delete",
|
||||
],
|
||||
contacts: [
|
||||
"contacts_search",
|
||||
"contacts_list_books",
|
||||
"contacts_list",
|
||||
"contacts_get",
|
||||
"contacts_create",
|
||||
"contacts_update",
|
||||
"contacts_delete",
|
||||
],
|
||||
search: ["suite_search"],
|
||||
docs: ["docs_read", "docs_save"],
|
||||
agenda: [
|
||||
"calendar_list",
|
||||
"calendar_create",
|
||||
"calendar_update",
|
||||
"calendar_delete",
|
||||
"calendar_list_events",
|
||||
"calendar_freebusy",
|
||||
"calendar_create_event",
|
||||
"calendar_update_event",
|
||||
"calendar_delete_event",
|
||||
"calendar_respond_invitation",
|
||||
"calendar_create_meet_link",
|
||||
],
|
||||
web_search: ["web_search"],
|
||||
};
|
||||
const ALL_TOOL_NAMES = [...new Set(Object.values(TOOL_GROUPS).flat())];
|
||||
function parseEnabledTools(raw) {
|
||||
const groups = String(raw ?? "")
|
||||
.split(",")
|
||||
.map((part) => part.trim().toLowerCase())
|
||||
.filter(Boolean);
|
||||
if (groups.length === 0)
|
||||
return null;
|
||||
const enabled = new Set();
|
||||
for (const group of groups) {
|
||||
for (const tool of TOOL_GROUPS[group] ?? []) {
|
||||
enabled.add(tool);
|
||||
}
|
||||
}
|
||||
return enabled.size > 0 ? enabled : null;
|
||||
}
|
||||
function encodePath(path) {
|
||||
return path
|
||||
.replace(/^\/+/, "")
|
||||
.split("/")
|
||||
.filter(Boolean)
|
||||
.map((seg) => encodeURIComponent(seg))
|
||||
.join("/");
|
||||
}
|
||||
function toolText(data) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: typeof data === "string" ? data : JSON.stringify(data, null, 2),
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
/** Ultid list APIs use page + page_size, not limit + offset. */
|
||||
function listPaginationParams(limit, offset) {
|
||||
const qs = new URLSearchParams();
|
||||
const pageSize = limit != null && Number.isFinite(limit) && limit > 0
|
||||
? Math.min(Math.trunc(limit), 500)
|
||||
: undefined;
|
||||
if (pageSize != null) {
|
||||
qs.set("page_size", String(pageSize));
|
||||
const off = offset != null && Number.isFinite(offset) && offset > 0 ? Math.trunc(offset) : 0;
|
||||
qs.set("page", String(Math.floor(off / pageSize) + 1));
|
||||
}
|
||||
else if (offset != null && Number.isFinite(offset) && offset > 0) {
|
||||
qs.set("page_size", "50");
|
||||
qs.set("page", String(Math.floor(Math.trunc(offset) / 50) + 1));
|
||||
}
|
||||
return qs;
|
||||
}
|
||||
function withListPagination(base, limit, offset) {
|
||||
const qs = new URLSearchParams(base);
|
||||
for (const [key, value] of listPaginationParams(limit, offset)) {
|
||||
qs.set(key, value);
|
||||
}
|
||||
return qs;
|
||||
}
|
||||
function applyHiddenMailboxParams(qs, include_spam, include_trash) {
|
||||
if (include_spam)
|
||||
qs.set("include_spam", "true");
|
||||
if (include_trash)
|
||||
qs.set("include_trash", "true");
|
||||
}
|
||||
async function ultiFetch(token, path, init) {
|
||||
const headers = new Headers(init?.headers);
|
||||
headers.set("Accept", "application/json");
|
||||
if (token)
|
||||
headers.set("Authorization", `Bearer ${token}`);
|
||||
const res = await fetch(`${API_BASE}${path}`, { ...init, headers });
|
||||
const text = await res.text();
|
||||
if (!res.ok) {
|
||||
throw new Error(`ulti ${path} failed (${res.status}): ${text.slice(0, 500)}`);
|
||||
}
|
||||
if (!text)
|
||||
return { ok: true, status: res.status };
|
||||
try {
|
||||
return JSON.parse(text);
|
||||
}
|
||||
catch {
|
||||
return text;
|
||||
}
|
||||
}
|
||||
function resolveToken(req) {
|
||||
return (String(req.headers["x-ulti-token"] ?? req.headers.authorization ?? "").replace(/^Bearer\s+/i, "") || "");
|
||||
}
|
||||
function createServer(getToken, enabledTools) {
|
||||
const server = new McpServer({
|
||||
name: "ultimail-mcp",
|
||||
version: "0.1.0",
|
||||
});
|
||||
const allow = (toolName) => !enabledTools || enabledTools.has(toolName);
|
||||
if (allow("mail_search")) {
|
||||
server.tool("mail_search", "Search mail messages (spam and trash excluded by default)", {
|
||||
query: z.string(),
|
||||
account_id: z.string().optional(),
|
||||
limit: z.number().optional(),
|
||||
offset: z.number().optional(),
|
||||
include_spam: z.boolean().optional(),
|
||||
include_trash: z.boolean().optional(),
|
||||
}, async ({ query, account_id, limit, offset, include_spam, include_trash }) => {
|
||||
const qs = withListPagination(new URLSearchParams({ q: query }), limit, offset);
|
||||
if (account_id)
|
||||
qs.set("account_id", account_id);
|
||||
applyHiddenMailboxParams(qs, include_spam, include_trash);
|
||||
const data = await ultiFetch(getToken(), `/mail/search?${qs}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("mail_read")) {
|
||||
server.tool("mail_read", "Read a mail message by id", { message_id: z.string(), account_id: z.string().optional() }, async ({ message_id, account_id }) => {
|
||||
const qs = account_id ? `?account_id=${encodeURIComponent(account_id)}` : "";
|
||||
const data = await ultiFetch(getToken(), `/mail/messages/${message_id}${qs}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("mail_list")) {
|
||||
server.tool("mail_list", "List mail messages in a folder (spam and trash excluded by default unless folder=spam/trash)", {
|
||||
folder: z.string().optional(),
|
||||
account_id: z.string().optional(),
|
||||
limit: z.number().optional(),
|
||||
offset: z.number().optional(),
|
||||
include_spam: z.boolean().optional(),
|
||||
include_trash: z.boolean().optional(),
|
||||
}, async ({ folder, account_id, limit, offset, include_spam, include_trash }) => {
|
||||
let qs = new URLSearchParams();
|
||||
if (folder)
|
||||
qs.set("folder", folder);
|
||||
if (account_id)
|
||||
qs.set("account_id", account_id);
|
||||
qs = withListPagination(qs, limit, offset);
|
||||
applyHiddenMailboxParams(qs, include_spam, include_trash);
|
||||
const suffix = qs.size > 0 ? `?${qs}` : "";
|
||||
const data = await ultiFetch(getToken(), `/mail/messages${suffix}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("mail_send")) {
|
||||
server.tool("mail_send", "Send a mail message", {
|
||||
account_id: z.string(),
|
||||
to: z.array(z.string()),
|
||||
cc: z.array(z.string()).optional(),
|
||||
bcc: z.array(z.string()).optional(),
|
||||
subject: z.string(),
|
||||
body_text: z.string().optional(),
|
||||
body_html: z.string().optional(),
|
||||
in_reply_to: z.string().optional(),
|
||||
reply_to_message_id: z.string().optional(),
|
||||
schedule_at: z.string().optional(),
|
||||
}, async (body) => {
|
||||
const data = await ultiFetch(getToken(), "/mail/send", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
"Idempotency-Key": randomUUID(),
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("mail_update_labels")) {
|
||||
server.tool("mail_update_labels", "Update labels on a mail message", { message_id: z.string(), labels: z.array(z.string()) }, async ({ message_id, labels }) => {
|
||||
const data = await ultiFetch(getToken(), `/mail/messages/${message_id}/labels`, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ labels }),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("mail_update_flags")) {
|
||||
server.tool("mail_update_flags", "Update flags on a mail message", { message_id: z.string(), flags: z.array(z.string()) }, async ({ message_id, flags }) => {
|
||||
const data = await ultiFetch(getToken(), `/mail/messages/${message_id}/flags`, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ flags }),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("mail_delete")) {
|
||||
server.tool("mail_delete", "Delete a mail message", { message_id: z.string() }, async ({ message_id }) => {
|
||||
const data = await ultiFetch(getToken(), `/mail/messages/${message_id}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("drive_list")) {
|
||||
server.tool("drive_list", "List drive files in a folder", {
|
||||
path: z.string().optional(),
|
||||
limit: z.number().optional(),
|
||||
offset: z.number().optional(),
|
||||
}, async ({ path, limit, offset }) => {
|
||||
const encoded = encodePath(path ?? "");
|
||||
const qs = listPaginationParams(limit, offset);
|
||||
const suffix = qs.size > 0 ? `?${qs}` : "";
|
||||
const data = await ultiFetch(getToken(), `/drive/files/${encoded}${suffix}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("drive_file_info")) {
|
||||
server.tool("drive_file_info", "Get drive file or folder metadata", { path: z.string() }, async ({ path }) => {
|
||||
const encoded = encodePath(path);
|
||||
const data = await ultiFetch(getToken(), `/drive/files/info/${encoded}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("drive_create_file")) {
|
||||
server.tool("drive_create_file", "Create a new office file (document, spreadsheet, presentation, drawing)", {
|
||||
parent_path: z.string().optional(),
|
||||
name: z.string(),
|
||||
kind: z.enum(["document", "spreadsheet", "presentation", "drawing"]),
|
||||
}, async ({ parent_path, name, kind }) => {
|
||||
const data = await ultiFetch(getToken(), "/drive/files/new", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ parent_path, name, kind }),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("drive_create_folder")) {
|
||||
server.tool("drive_create_folder", "Create a drive folder", { path: z.string() }, async ({ path }) => {
|
||||
const encoded = encodePath(path);
|
||||
const data = await ultiFetch(getToken(), `/drive/folders/${encoded}`, {
|
||||
method: "POST",
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("drive_delete")) {
|
||||
server.tool("drive_delete", "Delete a drive file or folder", { path: z.string() }, async ({ path }) => {
|
||||
const encoded = encodePath(path);
|
||||
const data = await ultiFetch(getToken(), `/drive/files/${encoded}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("drive_move")) {
|
||||
server.tool("drive_move", "Move a drive file or folder", {
|
||||
source: z.string(),
|
||||
destination: z.string(),
|
||||
source_root: z.string().optional(),
|
||||
source_root_id: z.string().optional(),
|
||||
destination_root: z.string().optional(),
|
||||
destination_root_id: z.string().optional(),
|
||||
}, async (body) => {
|
||||
const data = await ultiFetch(getToken(), "/drive/move", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("drive_rename")) {
|
||||
server.tool("drive_rename", "Rename a drive file or folder", {
|
||||
path: z.string(),
|
||||
new_name: z.string(),
|
||||
root: z.string().optional(),
|
||||
root_id: z.string().optional(),
|
||||
}, async (body) => {
|
||||
const data = await ultiFetch(getToken(), "/drive/rename", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("drive_share_create")) {
|
||||
server.tool("drive_share_create", "Create a drive share link or internal share", {
|
||||
path: z.string(),
|
||||
share_type: z.number().optional(),
|
||||
permissions: z.number().optional(),
|
||||
role: z.enum(["owner", "editor", "viewer"]).optional(),
|
||||
mode: z.enum(["public", "internal", "contact"]).optional(),
|
||||
share_with: z.string().optional(),
|
||||
note: z.string().optional(),
|
||||
root: z.string().optional(),
|
||||
root_id: z.string().optional(),
|
||||
}, async (body) => {
|
||||
const data = await ultiFetch(getToken(), "/drive/shares", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("drive_share_update")) {
|
||||
server.tool("drive_share_update", "Update an existing drive share", {
|
||||
share_id: z.string(),
|
||||
permissions: z.number().optional(),
|
||||
role: z.enum(["owner", "editor", "viewer"]).optional(),
|
||||
expire_date: z.string().optional(),
|
||||
password: z.string().optional(),
|
||||
}, async ({ share_id, ...body }) => {
|
||||
const data = await ultiFetch(getToken(), `/drive/shares/${share_id}`, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("drive_share_delete")) {
|
||||
server.tool("drive_share_delete", "Delete a drive share", { share_id: z.string() }, async ({ share_id }) => {
|
||||
const data = await ultiFetch(getToken(), `/drive/shares/${share_id}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("contacts_search")) {
|
||||
server.tool("contacts_search", "Search contacts", {
|
||||
query: z.string(),
|
||||
book_id: z.string().optional(),
|
||||
limit: z.number().optional(),
|
||||
offset: z.number().optional(),
|
||||
}, async ({ query, book_id, limit, offset }) => {
|
||||
const qs = withListPagination(new URLSearchParams({ q: query }), limit, offset);
|
||||
if (book_id)
|
||||
qs.set("book_id", book_id);
|
||||
const data = await ultiFetch(getToken(), `/contacts/search?${qs}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("contacts_list_books")) {
|
||||
server.tool("contacts_list_books", "List address books", {}, async () => {
|
||||
const data = await ultiFetch(getToken(), "/contacts/books");
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("contacts_list")) {
|
||||
server.tool("contacts_list", "List contacts in an address book", {
|
||||
book_id: z.string(),
|
||||
limit: z.number().optional(),
|
||||
offset: z.number().optional(),
|
||||
}, async ({ book_id, limit, offset }) => {
|
||||
const qs = listPaginationParams(limit, offset);
|
||||
const suffix = qs.size > 0 ? `?${qs}` : "";
|
||||
const data = await ultiFetch(getToken(), `/contacts/books/${book_id}${suffix}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("contacts_get")) {
|
||||
server.tool("contacts_get", "Get a contact by path (book_id/contact_uid)", { path: z.string() }, async ({ path }) => {
|
||||
const encoded = encodePath(path);
|
||||
const data = await ultiFetch(getToken(), `/contacts/${encoded}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("contacts_create")) {
|
||||
server.tool("contacts_create", "Create a contact in an address book", {
|
||||
book_id: z.string(),
|
||||
contact: z.record(z.string(), z.unknown()),
|
||||
}, async ({ book_id, contact }) => {
|
||||
const data = await ultiFetch(getToken(), `/contacts/books/${book_id}`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(contact),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("contacts_update")) {
|
||||
server.tool("contacts_update", "Update a contact (requires etag from get)", {
|
||||
path: z.string(),
|
||||
if_match: z.string(),
|
||||
contact: z.record(z.string(), z.unknown()),
|
||||
}, async ({ path, if_match, contact }) => {
|
||||
const encoded = encodePath(path);
|
||||
const data = await ultiFetch(getToken(), `/contacts/${encoded}`, {
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
"If-Match": if_match,
|
||||
},
|
||||
body: JSON.stringify(contact),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("contacts_delete")) {
|
||||
server.tool("contacts_delete", "Delete a contact by path", { path: z.string() }, async ({ path }) => {
|
||||
const encoded = encodePath(path);
|
||||
const data = await ultiFetch(getToken(), `/contacts/${encoded}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("suite_search")) {
|
||||
server.tool("suite_search", "Unified search across mail, drive, contacts (mail spam/trash excluded by default)", {
|
||||
query: z.string(),
|
||||
types: z.string().optional(),
|
||||
limit: z.number().optional(),
|
||||
offset: z.number().optional(),
|
||||
include_spam: z.boolean().optional(),
|
||||
include_trash: z.boolean().optional(),
|
||||
}, async ({ query, types, limit, offset, include_spam, include_trash }) => {
|
||||
const qs = withListPagination(new URLSearchParams({ q: query }), limit, offset);
|
||||
if (types)
|
||||
qs.set("types", types);
|
||||
applyHiddenMailboxParams(qs, include_spam, include_trash);
|
||||
const data = await ultiFetch(getToken(), `/search?${qs}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("web_search")) {
|
||||
server.tool("web_search", "Search the public web using the configured search provider (Brave, Bing, SearXNG, DuckDuckGo or custom JSON API).", {
|
||||
query: z.string(),
|
||||
count: z.number().optional(),
|
||||
}, async ({ query, count }) => {
|
||||
const qs = new URLSearchParams({ q: query });
|
||||
if (count != null && Number.isFinite(count) && count > 0) {
|
||||
qs.set("count", String(Math.min(Math.trunc(count), 20)));
|
||||
}
|
||||
const data = await ultiFetch(getToken(), `/search/web?${qs}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("docs_read")) {
|
||||
server.tool("docs_read", "Read UltiDocs document JSON (.ultidoc sidecar path)", { path: z.string() }, async ({ path }) => {
|
||||
const encoded = encodePath(path);
|
||||
const data = await ultiFetch(getToken(), `/drive/download/${encoded}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("docs_save")) {
|
||||
server.tool("docs_save", "Save UltiDocs TipTap content to sidecar path", {
|
||||
path: z.string(),
|
||||
document: z.record(z.string(), z.unknown()),
|
||||
}, async ({ path, document }) => {
|
||||
const data = await ultiFetch(getToken(), "/richtext/save", {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ path, document }),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("calendar_list")) {
|
||||
server.tool("calendar_list", "List calendars", {}, async () => {
|
||||
const data = await ultiFetch(getToken(), "/calendar/");
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("calendar_create")) {
|
||||
server.tool("calendar_create", "Create a calendar", {
|
||||
id: z.string().optional(),
|
||||
display_name: z.string(),
|
||||
color: z.string().optional(),
|
||||
}, async (body) => {
|
||||
const data = await ultiFetch(getToken(), "/calendar/", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("calendar_update")) {
|
||||
server.tool("calendar_update", "Update calendar display name or color", {
|
||||
cal_id: z.string(),
|
||||
display_name: z.string().optional(),
|
||||
color: z.string().optional(),
|
||||
}, async ({ cal_id, ...body }) => {
|
||||
const data = await ultiFetch(getToken(), `/calendar/${cal_id}`, {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("calendar_delete")) {
|
||||
server.tool("calendar_delete", "Delete a calendar", { cal_id: z.string() }, async ({ cal_id }) => {
|
||||
const data = await ultiFetch(getToken(), `/calendar/${cal_id}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("calendar_list_events")) {
|
||||
server.tool("calendar_list_events", "List events in a calendar", {
|
||||
cal_id: z.string(),
|
||||
limit: z.number().optional(),
|
||||
offset: z.number().optional(),
|
||||
}, async ({ cal_id, limit, offset }) => {
|
||||
const qs = listPaginationParams(limit, offset);
|
||||
const suffix = qs.size > 0 ? `?${qs}` : "";
|
||||
const data = await ultiFetch(getToken(), `/calendar/${cal_id}/events${suffix}`);
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("calendar_freebusy")) {
|
||||
server.tool("calendar_freebusy", "Query free/busy for attendees in a time range", {
|
||||
start: z.string(),
|
||||
end: z.string(),
|
||||
attendees: z.array(z.string()),
|
||||
}, async (body) => {
|
||||
const data = await ultiFetch(getToken(), "/calendar/freebusy", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("calendar_create_event")) {
|
||||
server.tool("calendar_create_event", "Create a calendar event", {
|
||||
cal_id: z.string(),
|
||||
event: z.record(z.string(), z.unknown()),
|
||||
}, async ({ cal_id, event }) => {
|
||||
const data = await ultiFetch(getToken(), `/calendar/${cal_id}/events`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(event),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("calendar_update_event")) {
|
||||
server.tool("calendar_update_event", "Update a calendar event (requires If-Match etag)", {
|
||||
event_path: z.string(),
|
||||
if_match: z.string(),
|
||||
event: z.record(z.string(), z.unknown()),
|
||||
}, async ({ event_path, if_match, event }) => {
|
||||
const encoded = encodePath(event_path);
|
||||
const data = await ultiFetch(getToken(), `/calendar/events/${encoded}`, {
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
"If-Match": if_match,
|
||||
},
|
||||
body: JSON.stringify(event),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("calendar_delete_event")) {
|
||||
server.tool("calendar_delete_event", "Delete a calendar event", { event_path: z.string() }, async ({ event_path }) => {
|
||||
const encoded = encodePath(event_path);
|
||||
const data = await ultiFetch(getToken(), `/calendar/events/${encoded}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("calendar_respond_invitation")) {
|
||||
server.tool("calendar_respond_invitation", "Respond to a calendar invitation", {
|
||||
event_path: z.string(),
|
||||
email: z.string().optional(),
|
||||
response: z.enum(["accepted", "declined", "tentative", "needs-action"]),
|
||||
if_match: z.string().optional(),
|
||||
}, async ({ event_path, ...body }) => {
|
||||
const encoded = encodePath(event_path);
|
||||
const data = await ultiFetch(getToken(), `/calendar/events/response/${encoded}`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (allow("calendar_create_meet_link")) {
|
||||
server.tool("calendar_create_meet_link", "Create a video meeting link for an event", {
|
||||
event_path: z.string(),
|
||||
if_match: z.string().optional(),
|
||||
}, async ({ event_path, if_match }) => {
|
||||
const encoded = encodePath(event_path);
|
||||
const data = await ultiFetch(getToken(), `/calendar/events/meet-link/${encoded}`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ if_match }),
|
||||
});
|
||||
return toolText(data);
|
||||
});
|
||||
}
|
||||
if (enabledTools && enabledTools.size === 0) {
|
||||
console.warn("ultimail-mcp: no tools enabled for groups", [...enabledTools]);
|
||||
}
|
||||
else if (enabledTools) {
|
||||
const disabled = ALL_TOOL_NAMES.filter((name) => !enabledTools.has(name));
|
||||
if (disabled.length > 0) {
|
||||
console.log("ultimail-mcp disabled tools:", disabled.join(", "));
|
||||
}
|
||||
}
|
||||
return server;
|
||||
}
|
||||
const app = express();
|
||||
app.use(express.json({ limit: "4mb" }));
|
||||
const streamableTransports = new Map();
|
||||
const sseTransports = new Map();
|
||||
app.get("/health", (_req, res) => {
|
||||
res.json({ ok: true, tools: ALL_TOOL_NAMES });
|
||||
});
|
||||
app.all("/mcp", async (req, res) => {
|
||||
const token = resolveToken(req);
|
||||
const enabledTools = parseEnabledTools(String(req.headers["x-ulti-enabled-tools"] ?? ""));
|
||||
try {
|
||||
const sessionId = String(req.headers["mcp-session-id"] ?? "");
|
||||
let transport;
|
||||
if (sessionId && streamableTransports.has(sessionId)) {
|
||||
transport = streamableTransports.get(sessionId);
|
||||
}
|
||||
else if (!sessionId && req.method === "POST" && isInitializeRequest(req.body)) {
|
||||
transport = new StreamableHTTPServerTransport({
|
||||
sessionIdGenerator: () => randomUUID(),
|
||||
onsessioninitialized: (sid) => {
|
||||
if (transport)
|
||||
streamableTransports.set(sid, transport);
|
||||
},
|
||||
});
|
||||
transport.onclose = () => {
|
||||
const sid = transport?.sessionId;
|
||||
if (sid)
|
||||
streamableTransports.delete(sid);
|
||||
};
|
||||
const server = createServer(() => token, enabledTools);
|
||||
await server.connect(transport);
|
||||
}
|
||||
else {
|
||||
res.status(400).json({
|
||||
jsonrpc: "2.0",
|
||||
error: { code: -32000, message: "Bad Request: No valid session ID provided" },
|
||||
id: null,
|
||||
});
|
||||
return;
|
||||
}
|
||||
await transport.handleRequest(req, res, req.body);
|
||||
}
|
||||
catch (error) {
|
||||
console.error("streamable mcp error:", error);
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({
|
||||
jsonrpc: "2.0",
|
||||
error: { code: -32603, message: "Internal server error" },
|
||||
id: null,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
app.get("/sse", async (req, res) => {
|
||||
const token = resolveToken(req);
|
||||
const enabledTools = parseEnabledTools(String(req.headers["x-ulti-enabled-tools"] ?? ""));
|
||||
const transport = new SSEServerTransport("/messages", res);
|
||||
sseTransports.set(transport.sessionId, transport);
|
||||
res.on("close", () => sseTransports.delete(transport.sessionId));
|
||||
const server = createServer(() => token, enabledTools);
|
||||
await server.connect(transport);
|
||||
});
|
||||
app.post("/messages", async (req, res) => {
|
||||
const sessionId = String(req.query.sessionId ?? "");
|
||||
const transport = sseTransports.get(sessionId);
|
||||
if (!transport || !(transport instanceof SSEServerTransport)) {
|
||||
res.status(404).json({ error: "session not found" });
|
||||
return;
|
||||
}
|
||||
await transport.handlePostMessage(req, res, req.body);
|
||||
});
|
||||
app.listen(PORT, () => {
|
||||
console.log(`ultimail-mcp listening on :${PORT} (streamable /mcp, legacy /sse)`);
|
||||
});
|
||||
1800
services/ultimail-mcp/package-lock.json
generated
Normal file
1800
services/ultimail-mcp/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
@ -9,7 +9,7 @@
|
||||
"dev": "tsx src/index.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
"@modelcontextprotocol/sdk": "^1.12.1",
|
||||
"@modelcontextprotocol/sdk": "^1.29.0",
|
||||
"express": "^5.1.0",
|
||||
"zod": "^3.24.4"
|
||||
},
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user