368 lines
9.5 KiB
Go
368 lines
9.5 KiB
Go
package main
|
|
|
|
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"os"
	"strconv"
	"sync"
	"time"

	"agentmon/internal/httpx"
	"agentmon/internal/store/postgres"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/gorilla/websocket"
	"github.com/nats-io/nats.go"
)
|
|
|
|
type wsClient struct {
|
|
conn *websocket.Conn
|
|
send chan []byte
|
|
}
|
|
|
|
var (
|
|
wsUpgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool { return true },
|
|
}
|
|
wsClients = make(map[*wsClient]bool)
|
|
wsMu sync.RWMutex
|
|
natsConn *nats.Conn
|
|
)
|
|
|
|
func subscribeToNATS(nc *nats.Conn) {
|
|
topic := envDefault("NATS_TOPIC", "agentmon.events.v1")
|
|
sub, err := nc.Subscribe(topic, func(msg *nats.Msg) {
|
|
wsMu.RLock()
|
|
for client := range wsClients {
|
|
select {
|
|
case client.send <- msg.Data:
|
|
default:
|
|
// Slow client; close and remove in background.
|
|
go removeClient(client)
|
|
}
|
|
}
|
|
wsMu.RUnlock()
|
|
})
|
|
if err != nil {
|
|
log.Printf("failed to subscribe to NATS: %v", err)
|
|
return
|
|
}
|
|
log.Printf("subscribed to NATS topic: %s", topic)
|
|
_ = sub
|
|
}
|
|
|
|
func removeClient(c *wsClient) {
|
|
wsMu.Lock()
|
|
if wsClients[c] {
|
|
delete(wsClients, c)
|
|
close(c.send)
|
|
c.conn.Close()
|
|
}
|
|
wsMu.Unlock()
|
|
}
|
|
|
|
func wsHandler(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := wsUpgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
client := &wsClient{
|
|
conn: conn,
|
|
send: make(chan []byte, 256),
|
|
}
|
|
|
|
wsMu.Lock()
|
|
wsClients[client] = true
|
|
wsMu.Unlock()
|
|
|
|
log.Printf("WebSocket client connected")
|
|
|
|
// Writer goroutine: sole owner of conn writes.
|
|
go func() {
|
|
defer conn.Close()
|
|
for msg := range client.send {
|
|
if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil {
|
|
break
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Read loop blocks until the client disconnects.
|
|
for {
|
|
_, _, err := conn.ReadMessage()
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
wsMu.Lock()
|
|
if wsClients[client] {
|
|
delete(wsClients, client)
|
|
close(client.send)
|
|
}
|
|
wsMu.Unlock()
|
|
|
|
log.Printf("WebSocket client disconnected")
|
|
}
|
|
|
|
func main() {
|
|
addr := envDefault("AGENTMON_QUERY_ADDR", ":8081")
|
|
dsn := os.Getenv("DATABASE_URL")
|
|
natsURL := envDefault("NATS_URL", "nats://localhost:4222")
|
|
|
|
if dsn == "" {
|
|
log.Fatalf("DATABASE_URL is required")
|
|
}
|
|
|
|
db, err := postgres.Open(dsn)
|
|
if err != nil {
|
|
log.Fatalf("failed to open DB: %v", err)
|
|
}
|
|
defer func() { _ = db.Close() }()
|
|
|
|
nc, err := nats.Connect(natsURL)
|
|
if err != nil {
|
|
log.Printf("warning: failed to connect to NATS: %v", err)
|
|
} else {
|
|
natsConn = nc
|
|
go subscribeToNATS(nc)
|
|
}
|
|
|
|
r := chi.NewRouter()
|
|
r.Use(middleware.RequestID)
|
|
r.Use(middleware.RealIP)
|
|
r.Use(middleware.Logger)
|
|
r.Use(middleware.Recoverer)
|
|
|
|
r.Get("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("ok"))
|
|
})
|
|
|
|
r.Get("/v1/ws", wsHandler)
|
|
|
|
r.Get("/v1/events", func(w http.ResponseWriter, r *http.Request) {
|
|
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
|
f := postgres.EventsFilter{
|
|
Limit: limit,
|
|
EventType: r.URL.Query().Get("event_type"),
|
|
Framework: r.URL.Query().Get("framework"),
|
|
ClientID: r.URL.Query().Get("client_id"),
|
|
}
|
|
events, err := db.ListRecentEvents(r.Context(), f)
|
|
if err != nil {
|
|
httpx.WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "db_error"})
|
|
return
|
|
}
|
|
httpx.WriteJSON(w, http.StatusOK, map[string]any{"events": events})
|
|
})
|
|
|
|
r.Get("/v1/sessions", func(w http.ResponseWriter, r *http.Request) {
|
|
q := r.URL.Query()
|
|
f := postgres.SessionsFilter{
|
|
Framework: q.Get("framework"),
|
|
Host: q.Get("host"),
|
|
}
|
|
|
|
if v := q.Get("limit"); v != "" {
|
|
f.Limit, _ = strconv.Atoi(v)
|
|
}
|
|
|
|
if v := q.Get("from"); v != "" {
|
|
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
|
f.From = &t
|
|
} else if t, err := time.Parse("2006-01-02", v); err == nil {
|
|
f.From = &t
|
|
}
|
|
}
|
|
|
|
if v := q.Get("to"); v != "" {
|
|
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
|
f.To = &t
|
|
} else if t, err := time.Parse("2006-01-02", v); err == nil {
|
|
end := t.Add(24*time.Hour - time.Nanosecond)
|
|
f.To = &end
|
|
}
|
|
}
|
|
|
|
if v := q.Get("cursor"); v != "" {
|
|
if t, err := time.Parse(time.RFC3339Nano, v); err == nil {
|
|
f.Cursor = &t
|
|
}
|
|
}
|
|
|
|
sessions, nextCursor, err := db.ListSessions(r.Context(), f)
|
|
if err != nil {
|
|
httpx.WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "db_error"})
|
|
return
|
|
}
|
|
|
|
resp := map[string]any{"sessions": sessions}
|
|
if nextCursor != nil {
|
|
resp["next_cursor"] = nextCursor.Format(time.RFC3339Nano)
|
|
}
|
|
|
|
// Include total count on the first page (no cursor) so the UI can show "X of Y"
|
|
if f.Cursor == nil {
|
|
total, err := db.CountSessions(r.Context(), f)
|
|
if err == nil {
|
|
resp["total"] = total
|
|
}
|
|
}
|
|
|
|
httpx.WriteJSON(w, http.StatusOK, resp)
|
|
})
|
|
|
|
r.Get("/v1/sessions/{sessionID}", func(w http.ResponseWriter, r *http.Request) {
|
|
sessionID := chi.URLParam(r, "sessionID")
|
|
session, runs, err := db.GetSessionWithRuns(r.Context(), sessionID)
|
|
if err == sql.ErrNoRows {
|
|
httpx.WriteJSON(w, http.StatusNotFound, map[string]any{"error": "not_found"})
|
|
return
|
|
}
|
|
if err != nil {
|
|
httpx.WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "db_error"})
|
|
return
|
|
}
|
|
httpx.WriteJSON(w, http.StatusOK, map[string]any{"session": session, "runs": runs})
|
|
})
|
|
|
|
r.Get("/v1/agents/live", func(w http.ResponseWriter, r *http.Request) {
|
|
clientID := r.URL.Query().Get("client_id")
|
|
framework := r.URL.Query().Get("framework")
|
|
if clientID == "" || framework == "" {
|
|
httpx.WriteJSON(w, http.StatusBadRequest, map[string]any{"error": "missing_agent_selector"})
|
|
return
|
|
}
|
|
|
|
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
|
events, err := db.ListAgentLiveEvents(r.Context(), framework, clientID, limit)
|
|
if err != nil {
|
|
httpx.WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "db_error"})
|
|
return
|
|
}
|
|
httpx.WriteJSON(w, http.StatusOK, map[string]any{"events": events})
|
|
})
|
|
|
|
r.Get("/v1/runs/{runID}", func(w http.ResponseWriter, r *http.Request) {
|
|
runID := chi.URLParam(r, "runID")
|
|
run, spans, err := db.GetRunWithSpans(r.Context(), runID)
|
|
if err == sql.ErrNoRows {
|
|
httpx.WriteJSON(w, http.StatusNotFound, map[string]any{"error": "not_found"})
|
|
return
|
|
}
|
|
if err != nil {
|
|
httpx.WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "db_error"})
|
|
return
|
|
}
|
|
httpx.WriteJSON(w, http.StatusOK, map[string]any{"run": run, "spans": spans})
|
|
})
|
|
|
|
r.Get("/v1/stats/summary", func(w http.ResponseWriter, r *http.Request) {
|
|
summary, err := db.GetSummary(r.Context())
|
|
if err != nil {
|
|
httpx.WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "db_error"})
|
|
return
|
|
}
|
|
httpx.WriteJSON(w, http.StatusOK, summary)
|
|
})
|
|
|
|
r.Get("/v1/stats/top-tools", func(w http.ResponseWriter, r *http.Request) {
|
|
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
|
tools, err := db.GetTopTools(r.Context(), limit)
|
|
if err != nil {
|
|
httpx.WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "db_error"})
|
|
return
|
|
}
|
|
if tools == nil {
|
|
tools = []postgres.TopTool{}
|
|
}
|
|
httpx.WriteJSON(w, http.StatusOK, map[string]any{"tools": tools})
|
|
})
|
|
|
|
r.Get("/v1/stats/top-models", func(w http.ResponseWriter, r *http.Request) {
|
|
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
|
models, err := db.GetTopModels(r.Context(), limit)
|
|
if err != nil {
|
|
httpx.WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "db_error"})
|
|
return
|
|
}
|
|
if models == nil {
|
|
models = []postgres.TopModel{}
|
|
}
|
|
httpx.WriteJSON(w, http.StatusOK, map[string]any{"models": models})
|
|
})
|
|
|
|
r.Post("/v1/admin/retention", func(w http.ResponseWriter, r *http.Request) {
|
|
var req struct {
|
|
Days int `json:"days"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Days <= 0 {
|
|
httpx.WriteJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid_request", "message": "days must be a positive integer"})
|
|
return
|
|
}
|
|
cutoff := time.Now().AddDate(0, 0, -req.Days)
|
|
deleted, err := db.DeleteOlderThan(r.Context(), cutoff)
|
|
if err != nil {
|
|
httpx.WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "db_error"})
|
|
return
|
|
}
|
|
httpx.WriteJSON(w, http.StatusOK, map[string]any{"deleted": deleted, "cutoff": cutoff.Format(time.RFC3339)})
|
|
})
|
|
|
|
r.Get("/v1/stats/timeseries", func(w http.ResponseWriter, r *http.Request) {
|
|
window := r.URL.Query().Get("window")
|
|
switch window {
|
|
case "1h", "6h", "24h", "7d":
|
|
case "":
|
|
window = "1h"
|
|
default:
|
|
httpx.WriteJSON(w, http.StatusBadRequest, map[string]any{"error": "invalid_window"})
|
|
return
|
|
}
|
|
timeseries, err := db.GetTimeseries(r.Context(), window)
|
|
if err != nil {
|
|
httpx.WriteJSON(w, http.StatusInternalServerError, map[string]any{"error": "db_error"})
|
|
return
|
|
}
|
|
httpx.WriteJSON(w, http.StatusOK, timeseries)
|
|
})
|
|
|
|
// Background retention cleanup
|
|
retentionDays := 30
|
|
if v := os.Getenv("RETENTION_DAYS"); v != "" {
|
|
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
|
retentionDays = n
|
|
}
|
|
}
|
|
go func() {
|
|
ticker := time.NewTicker(24 * time.Hour)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
cutoff := time.Now().AddDate(0, 0, -retentionDays)
|
|
deleted, err := db.DeleteOlderThan(context.Background(), cutoff)
|
|
if err != nil {
|
|
log.Printf("retention cleanup error: %v", err)
|
|
} else if deleted > 0 {
|
|
log.Printf("retention cleanup: deleted %d events older than %s", deleted, cutoff.Format(time.RFC3339))
|
|
}
|
|
}
|
|
}()
|
|
|
|
log.Printf("query-api listening on %s", addr)
|
|
log.Fatal(http.ListenAndServe(addr, r))
|
|
}
|
|
|
|
// envDefault returns the value of the environment variable key, or def when
// the variable is unset or set to the empty string.
func envDefault(key, def string) string {
	value := os.Getenv(key)
	if value == "" {
		return def
	}
	return value
}
|