Add groups to mcpbox
Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
committed by
mudler
parent
33b8aaddfe
commit
eec88d74fe
@@ -9,72 +9,45 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// JSONRPCRequest represents a JSON-RPC request
|
||||
type JSONRPCRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params interface{} `json:"params"`
|
||||
}
|
||||
|
||||
// JSONRPCResponse represents a JSON-RPC response
|
||||
type JSONRPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *JSONRPCError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// JSONRPCError represents a JSON-RPC error
|
||||
type JSONRPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// JSONRPCNotification represents a JSON-RPC notification
|
||||
type JSONRPCNotification struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Notification struct {
|
||||
Method string `json:"method"`
|
||||
Params interface{} `json:"params,omitempty"`
|
||||
} `json:"notification"`
|
||||
}
|
||||
|
||||
// Client implements the transport.Interface for stdio processes
|
||||
type Client struct {
|
||||
baseURL string
|
||||
processID string
|
||||
conn *websocket.Conn
|
||||
mu sync.Mutex
|
||||
notifyChan chan JSONRPCNotification
|
||||
baseURL string
|
||||
processes map[string]*Process
|
||||
groups map[string][]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewClient creates a new stdio transport client
|
||||
func NewClient(baseURL string) *Client {
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
notifyChan: make(chan JSONRPCNotification, 100),
|
||||
baseURL: baseURL,
|
||||
processes: make(map[string]*Process),
|
||||
groups: make(map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Start initiates the connection to the server
|
||||
func (c *Client) Start(ctx context.Context) error {
|
||||
// Start a new process
|
||||
// CreateProcess starts a new process in a group
|
||||
func (c *Client) CreateProcess(ctx context.Context, command string, args []string, env []string, groupID string) (*Process, error) {
|
||||
req := struct {
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Env []string `json:"env"`
|
||||
GroupID string `json:"group_id"`
|
||||
}{
|
||||
Command: "./mcp_server",
|
||||
Args: []string{},
|
||||
Command: command,
|
||||
Args: args,
|
||||
Env: env,
|
||||
GroupID: groupID,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(
|
||||
@@ -83,7 +56,7 @@ func (c *Client) Start(ctx context.Context) error {
|
||||
bytes.NewReader(reqBody),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start process: %w", err)
|
||||
return nil, fmt.Errorf("failed to start process: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
@@ -91,124 +64,197 @@ func (c *Client) Start(ctx context.Context) error {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return fmt.Errorf("failed to decode response: %w", err)
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
c.processID = result.ID
|
||||
process := &Process{
|
||||
ID: result.ID,
|
||||
GroupID: groupID,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.processes[process.ID] = process
|
||||
if groupID != "" {
|
||||
c.groups[groupID] = append(c.groups[groupID], process.ID)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
return process, nil
|
||||
}
|
||||
|
||||
// GetProcess returns a process by ID
|
||||
func (c *Client) GetProcess(id string) (*Process, error) {
|
||||
c.mu.RLock()
|
||||
process, exists := c.processes[id]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("process not found: %s", id)
|
||||
}
|
||||
|
||||
return process, nil
|
||||
}
|
||||
|
||||
// GetGroupProcesses returns all processes in a group
|
||||
func (c *Client) GetGroupProcesses(groupID string) ([]*Process, error) {
|
||||
c.mu.RLock()
|
||||
processIDs, exists := c.groups[groupID]
|
||||
if !exists {
|
||||
c.mu.RUnlock()
|
||||
return nil, fmt.Errorf("group not found: %s", groupID)
|
||||
}
|
||||
|
||||
processes := make([]*Process, 0, len(processIDs))
|
||||
for _, pid := range processIDs {
|
||||
if process, exists := c.processes[pid]; exists {
|
||||
processes = append(processes, process)
|
||||
}
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
return processes, nil
|
||||
}
|
||||
|
||||
// StopProcess stops a single process
|
||||
func (c *Client) StopProcess(id string) error {
|
||||
c.mu.Lock()
|
||||
process, exists := c.processes[id]
|
||||
if !exists {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("process not found: %s", id)
|
||||
}
|
||||
|
||||
// Remove from group if it exists
|
||||
if process.GroupID != "" {
|
||||
groupProcesses := c.groups[process.GroupID]
|
||||
for i, pid := range groupProcesses {
|
||||
if pid == id {
|
||||
c.groups[process.GroupID] = append(groupProcesses[:i], groupProcesses[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(c.groups[process.GroupID]) == 0 {
|
||||
delete(c.groups, process.GroupID)
|
||||
}
|
||||
}
|
||||
|
||||
delete(c.processes, id)
|
||||
c.mu.Unlock()
|
||||
|
||||
req, err := http.NewRequest(
|
||||
"DELETE",
|
||||
fmt.Sprintf("%s/processes/%s", c.baseURL, id),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stop process: %w", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopGroup stops all processes in a group
|
||||
func (c *Client) StopGroup(groupID string) error {
|
||||
c.mu.Lock()
|
||||
processIDs, exists := c.groups[groupID]
|
||||
if !exists {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("group not found: %s", groupID)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
for _, pid := range processIDs {
|
||||
if err := c.StopProcess(pid); err != nil {
|
||||
return fmt.Errorf("failed to stop process %s in group %s: %w", pid, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListGroups returns all group IDs
|
||||
func (c *Client) ListGroups() []string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
groups := make([]string, 0, len(c.groups))
|
||||
for groupID := range c.groups {
|
||||
groups = append(groups, groupID)
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
// GetProcessIO returns io.Reader and io.Writer for a process
|
||||
func (c *Client) GetProcessIO(id string) (io.Reader, io.Writer, error) {
|
||||
process, err := c.GetProcess(id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Connect to WebSocket
|
||||
u := url.URL{
|
||||
Scheme: "ws",
|
||||
Host: c.baseURL,
|
||||
Path: fmt.Sprintf("/ws/%s", c.processID),
|
||||
Path: fmt.Sprintf("/ws/%s", process.ID),
|
||||
}
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||
return nil, nil, fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
// Create reader and writer
|
||||
reader := &websocketReader{conn: conn}
|
||||
writer := &websocketWriter{conn: conn}
|
||||
|
||||
// Start notification handler
|
||||
go c.handleNotifications()
|
||||
|
||||
return nil
|
||||
return reader, writer, nil
|
||||
}
|
||||
|
||||
// Close shuts down the client and closes the transport
|
||||
// websocketReader implements io.Reader for WebSocket
|
||||
type websocketReader struct {
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
func (r *websocketReader) Read(p []byte) (n int, err error) {
|
||||
_, message, err := r.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n = copy(p, message)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// websocketWriter implements io.Writer for WebSocket
|
||||
type websocketWriter struct {
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
func (w *websocketWriter) Write(p []byte) (n int, err error) {
|
||||
err = w.conn.WriteMessage(websocket.TextMessage, p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Close closes all connections and stops all processes
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
if c.processID != "" {
|
||||
req, err := http.NewRequest(
|
||||
"DELETE",
|
||||
fmt.Sprintf("%s/processes/%s", c.baseURL, c.processID),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
// Stop all processes
|
||||
for id := range c.processes {
|
||||
if err := c.StopProcess(id); err != nil {
|
||||
return fmt.Errorf("failed to stop process %s: %w", id, err)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stop process: %w", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendRequest sends a JSON-RPC request to the server
|
||||
func (c *Client) SendRequest(
|
||||
ctx context.Context,
|
||||
request JSONRPCRequest,
|
||||
) (*JSONRPCResponse, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
if err := c.conn.WriteJSON(request); err != nil {
|
||||
return nil, fmt.Errorf("failed to write request: %w", err)
|
||||
}
|
||||
|
||||
var response JSONRPCResponse
|
||||
if err := c.conn.ReadJSON(&response); err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// SendNotification sends a JSON-RPC notification to the server
|
||||
func (c *Client) SendNotification(
|
||||
ctx context.Context,
|
||||
notification JSONRPCNotification,
|
||||
) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
return c.conn.WriteJSON(notification)
|
||||
}
|
||||
|
||||
// SetNotificationHandler sets the handler for notifications
|
||||
func (c *Client) SetNotificationHandler(
|
||||
handler func(notification JSONRPCNotification),
|
||||
) {
|
||||
go func() {
|
||||
for notification := range c.notifyChan {
|
||||
handler(notification)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Client) handleNotifications() {
|
||||
for {
|
||||
var notification JSONRPCNotification
|
||||
if err := c.conn.ReadJSON(¬ification); err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case c.notifyChan <- notification:
|
||||
default:
|
||||
// Drop notification if channel is full
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user