Add groups to mcpbox

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-04-22 22:43:56 +02:00
committed by mudler
parent 33b8aaddfe
commit eec88d74fe
3 changed files with 354 additions and 179 deletions

View File

@@ -9,72 +9,45 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"sync" "sync"
"time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
// JSONRPCRequest represents a JSON-RPC request
type JSONRPCRequest struct {
JSONRPC string `json:"jsonrpc"`
ID int64 `json:"id"`
Method string `json:"method"`
Params interface{} `json:"params"`
}
// JSONRPCResponse represents a JSON-RPC response
type JSONRPCResponse struct {
JSONRPC string `json:"jsonrpc"`
ID int64 `json:"id"`
Result json.RawMessage `json:"result,omitempty"`
Error *JSONRPCError `json:"error,omitempty"`
}
// JSONRPCError represents a JSON-RPC error
type JSONRPCError struct {
Code int `json:"code"`
Message string `json:"message"`
}
// JSONRPCNotification represents a JSON-RPC notification
type JSONRPCNotification struct {
JSONRPC string `json:"jsonrpc"`
Notification struct {
Method string `json:"method"`
Params interface{} `json:"params,omitempty"`
} `json:"notification"`
}
// Client implements the transport.Interface for stdio processes // Client implements the transport.Interface for stdio processes
type Client struct { type Client struct {
baseURL string baseURL string
processID string processes map[string]*Process
conn *websocket.Conn groups map[string][]string
mu sync.Mutex mu sync.RWMutex
notifyChan chan JSONRPCNotification
} }
// NewClient creates a new stdio transport client // NewClient creates a new stdio transport client
func NewClient(baseURL string) *Client { func NewClient(baseURL string) *Client {
return &Client{ return &Client{
baseURL: baseURL, baseURL: baseURL,
notifyChan: make(chan JSONRPCNotification, 100), processes: make(map[string]*Process),
groups: make(map[string][]string),
} }
} }
// Start initiates the connection to the server // CreateProcess starts a new process in a group
func (c *Client) Start(ctx context.Context) error { func (c *Client) CreateProcess(ctx context.Context, command string, args []string, env []string, groupID string) (*Process, error) {
// Start a new process
req := struct { req := struct {
Command string `json:"command"` Command string `json:"command"`
Args []string `json:"args"` Args []string `json:"args"`
Env []string `json:"env"`
GroupID string `json:"group_id"`
}{ }{
Command: "./mcp_server", Command: command,
Args: []string{}, Args: args,
Env: env,
GroupID: groupID,
} }
reqBody, err := json.Marshal(req) reqBody, err := json.Marshal(req)
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal request: %w", err) return nil, fmt.Errorf("failed to marshal request: %w", err)
} }
resp, err := http.Post( resp, err := http.Post(
@@ -83,7 +56,7 @@ func (c *Client) Start(ctx context.Context) error {
bytes.NewReader(reqBody), bytes.NewReader(reqBody),
) )
if err != nil { if err != nil {
return fmt.Errorf("failed to start process: %w", err) return nil, fmt.Errorf("failed to start process: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -91,124 +64,197 @@ func (c *Client) Start(ctx context.Context) error {
ID string `json:"id"` ID string `json:"id"`
} }
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return fmt.Errorf("failed to decode response: %w", err) return nil, fmt.Errorf("failed to decode response: %w", err)
} }
c.processID = result.ID process := &Process{
ID: result.ID,
GroupID: groupID,
CreatedAt: time.Now(),
}
c.mu.Lock()
c.processes[process.ID] = process
if groupID != "" {
c.groups[groupID] = append(c.groups[groupID], process.ID)
}
c.mu.Unlock()
return process, nil
}
// GetProcess returns a process by ID
func (c *Client) GetProcess(id string) (*Process, error) {
c.mu.RLock()
process, exists := c.processes[id]
c.mu.RUnlock()
if !exists {
return nil, fmt.Errorf("process not found: %s", id)
}
return process, nil
}
// GetGroupProcesses returns all processes in a group
func (c *Client) GetGroupProcesses(groupID string) ([]*Process, error) {
c.mu.RLock()
processIDs, exists := c.groups[groupID]
if !exists {
c.mu.RUnlock()
return nil, fmt.Errorf("group not found: %s", groupID)
}
processes := make([]*Process, 0, len(processIDs))
for _, pid := range processIDs {
if process, exists := c.processes[pid]; exists {
processes = append(processes, process)
}
}
c.mu.RUnlock()
return processes, nil
}
// StopProcess stops a single process
func (c *Client) StopProcess(id string) error {
c.mu.Lock()
process, exists := c.processes[id]
if !exists {
c.mu.Unlock()
return fmt.Errorf("process not found: %s", id)
}
// Remove from group if it exists
if process.GroupID != "" {
groupProcesses := c.groups[process.GroupID]
for i, pid := range groupProcesses {
if pid == id {
c.groups[process.GroupID] = append(groupProcesses[:i], groupProcesses[i+1:]...)
break
}
}
if len(c.groups[process.GroupID]) == 0 {
delete(c.groups, process.GroupID)
}
}
delete(c.processes, id)
c.mu.Unlock()
req, err := http.NewRequest(
"DELETE",
fmt.Sprintf("%s/processes/%s", c.baseURL, id),
nil,
)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to stop process: %w", err)
}
resp.Body.Close()
return nil
}
// StopGroup stops all processes in a group
func (c *Client) StopGroup(groupID string) error {
c.mu.Lock()
processIDs, exists := c.groups[groupID]
if !exists {
c.mu.Unlock()
return fmt.Errorf("group not found: %s", groupID)
}
c.mu.Unlock()
for _, pid := range processIDs {
if err := c.StopProcess(pid); err != nil {
return fmt.Errorf("failed to stop process %s in group %s: %w", pid, groupID, err)
}
}
return nil
}
// ListGroups returns all group IDs
func (c *Client) ListGroups() []string {
c.mu.RLock()
defer c.mu.RUnlock()
groups := make([]string, 0, len(c.groups))
for groupID := range c.groups {
groups = append(groups, groupID)
}
return groups
}
// GetProcessIO returns io.Reader and io.Writer for a process
func (c *Client) GetProcessIO(id string) (io.Reader, io.Writer, error) {
process, err := c.GetProcess(id)
if err != nil {
return nil, nil, err
}
// Connect to WebSocket // Connect to WebSocket
u := url.URL{ u := url.URL{
Scheme: "ws", Scheme: "ws",
Host: c.baseURL, Host: c.baseURL,
Path: fmt.Sprintf("/ws/%s", c.processID), Path: fmt.Sprintf("/ws/%s", process.ID),
} }
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to connect to WebSocket: %w", err) return nil, nil, fmt.Errorf("failed to connect to WebSocket: %w", err)
} }
c.conn = conn // Create reader and writer
reader := &websocketReader{conn: conn}
writer := &websocketWriter{conn: conn}
// Start notification handler return reader, writer, nil
go c.handleNotifications()
return nil
} }
// Close shuts down the client and closes the transport // websocketReader implements io.Reader for WebSocket
type websocketReader struct {
conn *websocket.Conn
}
func (r *websocketReader) Read(p []byte) (n int, err error) {
_, message, err := r.conn.ReadMessage()
if err != nil {
return 0, err
}
n = copy(p, message)
return n, nil
}
// websocketWriter implements io.Writer for WebSocket
type websocketWriter struct {
conn *websocket.Conn
}
func (w *websocketWriter) Write(p []byte) (n int, err error) {
err = w.conn.WriteMessage(websocket.TextMessage, p)
if err != nil {
return 0, err
}
return len(p), nil
}
// Close closes all connections and stops all processes
func (c *Client) Close() error { func (c *Client) Close() error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.conn != nil { // Stop all processes
c.conn.Close() for id := range c.processes {
} if err := c.StopProcess(id); err != nil {
return fmt.Errorf("failed to stop process %s: %w", id, err)
if c.processID != "" {
req, err := http.NewRequest(
"DELETE",
fmt.Sprintf("%s/processes/%s", c.baseURL, c.processID),
nil,
)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
} }
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to stop process: %w", err)
}
resp.Body.Close()
} }
return nil return nil
} }
// SendRequest sends a JSON-RPC request to the server
func (c *Client) SendRequest(
ctx context.Context,
request JSONRPCRequest,
) (*JSONRPCResponse, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn == nil {
return nil, fmt.Errorf("not connected")
}
if err := c.conn.WriteJSON(request); err != nil {
return nil, fmt.Errorf("failed to write request: %w", err)
}
var response JSONRPCResponse
if err := c.conn.ReadJSON(&response); err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
return &response, nil
}
// SendNotification sends a JSON-RPC notification to the server
func (c *Client) SendNotification(
ctx context.Context,
notification JSONRPCNotification,
) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn == nil {
return fmt.Errorf("not connected")
}
return c.conn.WriteJSON(notification)
}
// SetNotificationHandler sets the handler for notifications
func (c *Client) SetNotificationHandler(
handler func(notification JSONRPCNotification),
) {
go func() {
for notification := range c.notifyChan {
handler(notification)
}
}()
}
func (c *Client) handleNotifications() {
for {
var notification JSONRPCNotification
if err := c.conn.ReadJSON(&notification); err != nil {
if err == io.EOF {
return
}
continue
}
select {
case c.notifyChan <- notification:
default:
// Drop notification if channel is full
}
}
}

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"log" "log"
"time" "time"
@@ -24,48 +25,60 @@ func main() {
// Create a client // Create a client
client := stdio.NewClient("localhost:8080") client := stdio.NewClient("localhost:8080")
// Start the client // Create a process group
if err := client.Start(context.Background()); err != nil { groupID := "test-group"
log.Fatalf("Failed to start client: %v", err)
}
defer client.Close()
// Set up notification handler // Start a process in the group
client.SetNotificationHandler(func(notification stdio.JSONRPCNotification) { process, err := client.CreateProcess(
fmt.Printf("Received notification: %+v\n", notification) context.Background(),
}) "echo",
[]string{"Hello, World!"},
// Send a request []string{"TEST=value"},
request := stdio.JSONRPCRequest{ groupID,
JSONRPC: "2.0", )
ID: 1,
Method: "test",
Params: map[string]string{"hello": "world"},
}
response, err := client.SendRequest(context.Background(), request)
if err != nil { if err != nil {
log.Fatalf("Failed to send request: %v", err) log.Fatalf("Failed to create process: %v", err)
} }
fmt.Printf("Received response: %+v\n", response) // Get IO streams for the process
reader, writer, err := client.GetProcessIO(process.ID)
// Send a notification if err != nil {
notification := stdio.JSONRPCNotification{ log.Fatalf("Failed to get process IO: %v", err)
JSONRPC: "2.0",
Notification: struct {
Method string `json:"method"`
Params interface{} `json:"params,omitempty"`
}{
Method: "test",
Params: map[string]string{"hello": "world"},
},
} }
if err := client.SendNotification(context.Background(), notification); err != nil { // Write to the process
log.Fatalf("Failed to send notification: %v", err) _, err = writer.Write([]byte("Hello from client\n"))
if err != nil {
log.Fatalf("Failed to write to process: %v", err)
} }
// Keep the program running // Read from the process
select {} buf := make([]byte, 1024)
n, err := reader.Read(buf)
if err != nil && err != io.EOF {
log.Fatalf("Failed to read from process: %v", err)
}
fmt.Printf("Process output: %s", buf[:n])
// Get all processes in the group
processes, err := client.GetGroupProcesses(groupID)
if err != nil {
log.Printf("Failed to get group processes: %v", err)
} else {
fmt.Printf("Processes in group %s: %+v\n", groupID, processes)
}
// List all groups
groups := client.ListGroups()
fmt.Printf("All groups: %v\n", groups)
// Stop the process
if err := client.StopProcess(process.ID); err != nil {
log.Fatalf("Failed to stop process: %v", err)
}
// Close the client
if err := client.Close(); err != nil {
log.Fatalf("Failed to close client: %v", err)
}
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os"
"os/exec" "os/exec"
"sync" "sync"
"time" "time"
@@ -16,6 +17,7 @@ import (
// Process represents a running process with its stdio streams // Process represents a running process with its stdio streams
type Process struct { type Process struct {
ID string ID string
GroupID string
Cmd *exec.Cmd Cmd *exec.Cmd
Stdin io.WriteCloser Stdin io.WriteCloser
Stdout io.ReadCloser Stdout io.ReadCloser
@@ -26,6 +28,7 @@ type Process struct {
// Server handles process management and stdio streaming // Server handles process management and stdio streaming
type Server struct { type Server struct {
processes map[string]*Process processes map[string]*Process
groups map[string][]string // maps group ID to process IDs
mu sync.RWMutex mu sync.RWMutex
upgrader websocket.Upgrader upgrader websocket.Upgrader
} }
@@ -34,6 +37,7 @@ type Server struct {
func NewServer() *Server { func NewServer() *Server {
return &Server{ return &Server{
processes: make(map[string]*Process), processes: make(map[string]*Process),
groups: make(map[string][]string),
upgrader: websocket.Upgrader{ upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true }, CheckOrigin: func(r *http.Request) bool { return true },
}, },
@@ -41,9 +45,13 @@ func NewServer() *Server {
} }
// StartProcess starts a new process and returns its ID // StartProcess starts a new process and returns its ID
func (s *Server) StartProcess(ctx context.Context, command string, args []string) (string, error) { func (s *Server) StartProcess(ctx context.Context, command string, args []string, env []string, groupID string) (string, error) {
cmd := exec.CommandContext(ctx, command, args...) cmd := exec.CommandContext(ctx, command, args...)
if len(env) > 0 {
cmd.Env = append(os.Environ(), env...)
}
stdin, err := cmd.StdinPipe() stdin, err := cmd.StdinPipe()
if err != nil { if err != nil {
return "", fmt.Errorf("failed to create stdin pipe: %w", err) return "", fmt.Errorf("failed to create stdin pipe: %w", err)
@@ -65,6 +73,7 @@ func (s *Server) StartProcess(ctx context.Context, command string, args []string
process := &Process{ process := &Process{
ID: fmt.Sprintf("%d", time.Now().UnixNano()), ID: fmt.Sprintf("%d", time.Now().UnixNano()),
GroupID: groupID,
Cmd: cmd, Cmd: cmd,
Stdin: stdin, Stdin: stdin,
Stdout: stdout, Stdout: stdout,
@@ -74,6 +83,9 @@ func (s *Server) StartProcess(ctx context.Context, command string, args []string
s.mu.Lock() s.mu.Lock()
s.processes[process.ID] = process s.processes[process.ID] = process
if groupID != "" {
s.groups[groupID] = append(s.groups[groupID], process.ID)
}
s.mu.Unlock() s.mu.Unlock()
return process.ID, nil return process.ID, nil
@@ -87,6 +99,21 @@ func (s *Server) StopProcess(id string) error {
s.mu.Unlock() s.mu.Unlock()
return fmt.Errorf("process not found: %s", id) return fmt.Errorf("process not found: %s", id)
} }
// Remove from group if it exists
if process.GroupID != "" {
groupProcesses := s.groups[process.GroupID]
for i, pid := range groupProcesses {
if pid == id {
s.groups[process.GroupID] = append(groupProcesses[:i], groupProcesses[i+1:]...)
break
}
}
if len(s.groups[process.GroupID]) == 0 {
delete(s.groups, process.GroupID)
}
}
delete(s.processes, id) delete(s.processes, id)
s.mu.Unlock() s.mu.Unlock()
@@ -97,6 +124,57 @@ func (s *Server) StopProcess(id string) error {
return nil return nil
} }
// StopGroup stops all processes in a group
func (s *Server) StopGroup(groupID string) error {
s.mu.Lock()
processIDs, exists := s.groups[groupID]
if !exists {
s.mu.Unlock()
return fmt.Errorf("group not found: %s", groupID)
}
s.mu.Unlock()
for _, pid := range processIDs {
if err := s.StopProcess(pid); err != nil {
return fmt.Errorf("failed to stop process %s in group %s: %w", pid, groupID, err)
}
}
return nil
}
// GetGroupProcesses returns all processes in a group
func (s *Server) GetGroupProcesses(groupID string) ([]*Process, error) {
s.mu.RLock()
processIDs, exists := s.groups[groupID]
if !exists {
s.mu.RUnlock()
return nil, fmt.Errorf("group not found: %s", groupID)
}
processes := make([]*Process, 0, len(processIDs))
for _, pid := range processIDs {
if process, exists := s.processes[pid]; exists {
processes = append(processes, process)
}
}
s.mu.RUnlock()
return processes, nil
}
// ListGroups returns all group IDs
func (s *Server) ListGroups() []string {
s.mu.RLock()
defer s.mu.RUnlock()
groups := make([]string, 0, len(s.groups))
for groupID := range s.groups {
groups = append(groups, groupID)
}
return groups
}
// GetProcess returns a process by ID // GetProcess returns a process by ID
func (s *Server) GetProcess(id string) (*Process, error) { func (s *Server) GetProcess(id string) (*Process, error) {
s.mu.RLock() s.mu.RLock()
@@ -128,6 +206,8 @@ func (s *Server) Start(addr string) error {
http.HandleFunc("/processes", s.handleProcesses) http.HandleFunc("/processes", s.handleProcesses)
http.HandleFunc("/processes/", s.handleProcess) http.HandleFunc("/processes/", s.handleProcess)
http.HandleFunc("/ws/", s.handleWebSocket) http.HandleFunc("/ws/", s.handleWebSocket)
http.HandleFunc("/groups", s.handleGroups)
http.HandleFunc("/groups/", s.handleGroup)
return http.ListenAndServe(addr, nil) return http.ListenAndServe(addr, nil)
} }
@@ -141,13 +221,15 @@ func (s *Server) handleProcesses(w http.ResponseWriter, r *http.Request) {
var req struct { var req struct {
Command string `json:"command"` Command string `json:"command"`
Args []string `json:"args"` Args []string `json:"args"`
Env []string `json:"env"`
GroupID string `json:"group_id"`
} }
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
id, err := s.StartProcess(r.Context(), req.Command, req.Args) id, err := s.StartProcess(r.Context(), req.Command, req.Args, req.Env, req.GroupID)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@@ -212,6 +294,7 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
go func() { go func() {
buf := make([]byte, 1024) buf := make([]byte, 1024)
for { for {
n, err := process.Stdout.Read(buf) n, err := process.Stdout.Read(buf)
if err != nil { if err != nil {
return return
@@ -235,3 +318,36 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
// Wait for process to exit // Wait for process to exit
process.Cmd.Wait() process.Cmd.Wait()
} }
// Add new handlers for group management
func (s *Server) handleGroups(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
groups := s.ListGroups()
json.NewEncoder(w).Encode(groups)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
}
func (s *Server) handleGroup(w http.ResponseWriter, r *http.Request) {
groupID := r.URL.Path[len("/groups/"):]
switch r.Method {
case http.MethodGet:
processes, err := s.GetGroupProcesses(groupID)
if err != nil {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
json.NewEncoder(w).Encode(processes)
case http.MethodDelete:
if err := s.StopGroup(groupID); err != nil {
http.Error(w, err.Error(), http.StatusNotFound)
return
}
w.WriteHeader(http.StatusNoContent)
default:
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
}