diff --git a/pkg/stdio/client.go b/pkg/stdio/client.go index a82491d..c461565 100644 --- a/pkg/stdio/client.go +++ b/pkg/stdio/client.go @@ -9,72 +9,45 @@ import ( "net/http" "net/url" "sync" + "time" "github.com/gorilla/websocket" ) -// JSONRPCRequest represents a JSON-RPC request -type JSONRPCRequest struct { - JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` - Method string `json:"method"` - Params interface{} `json:"params"` -} - -// JSONRPCResponse represents a JSON-RPC response -type JSONRPCResponse struct { - JSONRPC string `json:"jsonrpc"` - ID int64 `json:"id"` - Result json.RawMessage `json:"result,omitempty"` - Error *JSONRPCError `json:"error,omitempty"` -} - -// JSONRPCError represents a JSON-RPC error -type JSONRPCError struct { - Code int `json:"code"` - Message string `json:"message"` -} - -// JSONRPCNotification represents a JSON-RPC notification -type JSONRPCNotification struct { - JSONRPC string `json:"jsonrpc"` - Notification struct { - Method string `json:"method"` - Params interface{} `json:"params,omitempty"` - } `json:"notification"` -} - // Client implements the transport.Interface for stdio processes type Client struct { - baseURL string - processID string - conn *websocket.Conn - mu sync.Mutex - notifyChan chan JSONRPCNotification + baseURL string + processes map[string]*Process + groups map[string][]string + mu sync.RWMutex } // NewClient creates a new stdio transport client func NewClient(baseURL string) *Client { return &Client{ - baseURL: baseURL, - notifyChan: make(chan JSONRPCNotification, 100), + baseURL: baseURL, + processes: make(map[string]*Process), + groups: make(map[string][]string), } } -// Start initiates the connection to the server -func (c *Client) Start(ctx context.Context) error { - // Start a new process +// CreateProcess starts a new process in a group +func (c *Client) CreateProcess(ctx context.Context, command string, args []string, env []string, groupID string) (*Process, error) { req := struct { Command string `json:"command"` Args []string `json:"args"` + Env []string `json:"env"` + GroupID string `json:"group_id"` }{ - Command: "./mcp_server", - Args: []string{}, + Command: command, + Args: args, + Env: env, + GroupID: groupID, } reqBody, err := json.Marshal(req) if err != nil { - return fmt.Errorf("failed to marshal request: %w", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) } resp, err := http.Post( @@ -83,7 +56,7 @@ func (c *Client) Start(ctx context.Context) error { bytes.NewReader(reqBody), ) if err != nil { - return fmt.Errorf("failed to start process: %w", err) + return nil, fmt.Errorf("failed to start process: %w", err) } defer resp.Body.Close() @@ -91,124 +64,197 @@ func (c *Client) Start(ctx context.Context) error { ID string `json:"id"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return fmt.Errorf("failed to decode response: %w", err) + return nil, fmt.Errorf("failed to decode response: %w", err) } - c.processID = result.ID + process := &Process{ + ID: result.ID, + GroupID: groupID, + CreatedAt: time.Now(), + } + + c.mu.Lock() + c.processes[process.ID] = process + if groupID != "" { + c.groups[groupID] = append(c.groups[groupID], process.ID) + } + c.mu.Unlock() + + return process, nil +} + +// GetProcess returns a process by ID +func (c *Client) GetProcess(id string) (*Process, error) { + c.mu.RLock() + process, exists := c.processes[id] + c.mu.RUnlock() + + if !exists { + return nil, fmt.Errorf("process not found: %s", id) + } + + return process, nil +} + +// GetGroupProcesses returns all processes in a group +func (c *Client) GetGroupProcesses(groupID string) ([]*Process, error) { + c.mu.RLock() + processIDs, exists := c.groups[groupID] + if !exists { + c.mu.RUnlock() + return nil, fmt.Errorf("group not found: %s", groupID) + } + + processes := make([]*Process, 0, len(processIDs)) + for _, pid := range processIDs { + if process, exists := c.processes[pid]; exists { + processes = append(processes, process) + } + } + c.mu.RUnlock() + + return processes, nil +} + +// StopProcess stops a single process +func (c *Client) StopProcess(id string) error { + c.mu.Lock() + process, exists := c.processes[id] + if !exists { + c.mu.Unlock() + return fmt.Errorf("process not found: %s", id) + } + + // Remove from group if it exists + if process.GroupID != "" { + groupProcesses := c.groups[process.GroupID] + for i, pid := range groupProcesses { + if pid == id { + c.groups[process.GroupID] = append(groupProcesses[:i], groupProcesses[i+1:]...) + break + } + } + if len(c.groups[process.GroupID]) == 0 { + delete(c.groups, process.GroupID) + } + } + + delete(c.processes, id) + c.mu.Unlock() + + req, err := http.NewRequest( + "DELETE", + fmt.Sprintf("%s/processes/%s", c.baseURL, id), + nil, + ) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("failed to stop process: %w", err) + } + resp.Body.Close() + + return nil +} + +// StopGroup stops all processes in a group +func (c *Client) StopGroup(groupID string) error { + c.mu.Lock() + processIDs, exists := c.groups[groupID] + if !exists { + c.mu.Unlock() + return fmt.Errorf("group not found: %s", groupID) + } + c.mu.Unlock() + + for _, pid := range processIDs { + if err := c.StopProcess(pid); err != nil { + return fmt.Errorf("failed to stop process %s in group %s: %w", pid, groupID, err) + } + } + + return nil +} + +// ListGroups returns all group IDs +func (c *Client) ListGroups() []string { + c.mu.RLock() + defer c.mu.RUnlock() + + groups := make([]string, 0, len(c.groups)) + for groupID := range c.groups { + groups = append(groups, groupID) + } + return groups +} + +// GetProcessIO returns io.Reader and io.Writer for a process +func (c *Client) GetProcessIO(id string) (io.Reader, io.Writer, error) { + process, err := c.GetProcess(id) + if err != nil { + return nil, nil, err + } // Connect to WebSocket u := url.URL{ Scheme: "ws", Host: c.baseURL, - Path: fmt.Sprintf("/ws/%s", c.processID), + Path: fmt.Sprintf("/ws/%s", process.ID), } conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { - return fmt.Errorf("failed to connect to WebSocket: %w", err) + return nil, nil, fmt.Errorf("failed to connect to WebSocket: %w", err) } - c.conn = conn + // Create reader and writer + reader := &websocketReader{conn: conn} + writer := &websocketWriter{conn: conn} - // Start notification handler - go c.handleNotifications() - - return nil + return reader, writer, nil } -// Close shuts down the client and closes the transport +// websocketReader implements io.Reader for WebSocket +type websocketReader struct { + conn *websocket.Conn +} + +func (r *websocketReader) Read(p []byte) (n int, err error) { + _, message, err := r.conn.ReadMessage() + if err != nil { + return 0, err + } + n = copy(p, message) + return n, nil +} + +// websocketWriter implements io.Writer for WebSocket +type websocketWriter struct { + conn *websocket.Conn +} + +func (w *websocketWriter) Write(p []byte) (n int, err error) { + err = w.conn.WriteMessage(websocket.TextMessage, p) + if err != nil { + return 0, err + } + return len(p), nil +} + +// Close closes all connections and stops all processes func (c *Client) Close() error { c.mu.Lock() defer c.mu.Unlock() - if c.conn != nil { - c.conn.Close() - } - - if c.processID != "" { - req, err := http.NewRequest( - "DELETE", - fmt.Sprintf("%s/processes/%s", c.baseURL, c.processID), - nil, - ) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) + // Stop all processes + for id := range c.processes { + if err := c.StopProcess(id); err != nil { + return fmt.Errorf("failed to stop process %s: %w", id, err) } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return fmt.Errorf("failed to stop process: %w", err) - } - resp.Body.Close() } return nil } - -// SendRequest sends a JSON-RPC request to the server -func (c *Client) SendRequest( - ctx context.Context, - request JSONRPCRequest, -) (*JSONRPCResponse, error) { - c.mu.Lock() - defer c.mu.Unlock() - - if c.conn == nil { - return nil, fmt.Errorf("not connected") - } - - if err := c.conn.WriteJSON(request); err != nil { - return nil, fmt.Errorf("failed to write request: %w", err) - } - - var response JSONRPCResponse - if err := c.conn.ReadJSON(&response); err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } - - return &response, nil -} - -// SendNotification sends a JSON-RPC notification to the server -func (c *Client) SendNotification( - ctx context.Context, - notification JSONRPCNotification, -) error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.conn == nil { - return fmt.Errorf("not connected") - } - - return c.conn.WriteJSON(notification) -} - -// SetNotificationHandler sets the handler for notifications -func (c *Client) SetNotificationHandler( - handler func(notification JSONRPCNotification), -) { - go func() { - for notification := range c.notifyChan { - handler(notification) - } - }() -} - -func (c *Client) handleNotifications() { - for { - var notification JSONRPCNotification - if err := c.conn.ReadJSON(¬ification); err != nil { - if err == io.EOF { - return - } - continue - } - - select { - case c.notifyChan <- notification: - default: - // Drop notification if channel is full - } - } -} diff --git a/pkg/stdio/example/main.go b/pkg/stdio/example/main.go index d5e348c..becdbe8 100644 --- a/pkg/stdio/example/main.go +++ b/pkg/stdio/example/main.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "io" "log" "time" @@ -24,48 +25,60 @@ func main() { // Create a client client := stdio.NewClient("localhost:8080") - // Start the client - if err := client.Start(context.Background()); err != nil { - log.Fatalf("Failed to start client: %v", err) - } - defer client.Close() + // Create a process group + groupID := "test-group" - // Set up notification handler - client.SetNotificationHandler(func(notification stdio.JSONRPCNotification) { - fmt.Printf("Received notification: %+v\n", notification) - }) - - // Send a request - request := stdio.JSONRPCRequest{ - JSONRPC: "2.0", - ID: 1, - Method: "test", - Params: map[string]string{"hello": "world"}, - } - - response, err := client.SendRequest(context.Background(), request) + // Start a process in the group + process, err := client.CreateProcess( + context.Background(), + "echo", + []string{"Hello, World!"}, + []string{"TEST=value"}, + groupID, + ) if err != nil { - log.Fatalf("Failed to send request: %v", err) + log.Fatalf("Failed to create process: %v", err) } - fmt.Printf("Received response: %+v\n", response) - - // Send a notification - notification := stdio.JSONRPCNotification{ - JSONRPC: "2.0", - Notification: struct { - Method string `json:"method"` - Params interface{} `json:"params,omitempty"` - }{ - Method: "test", - Params: map[string]string{"hello": "world"}, - }, + // Get IO streams for the process + reader, writer, err := client.GetProcessIO(process.ID) + if err != nil { + log.Fatalf("Failed to get process IO: %v", err) } - if err := client.SendNotification(context.Background(), notification); err != nil { - log.Fatalf("Failed to send notification: %v", err) + // Write to the process + _, err = writer.Write([]byte("Hello from client\n")) + if err != nil { + log.Fatalf("Failed to write to process: %v", err) } - // Keep the program running - select {} + // Read from the process + buf := make([]byte, 1024) + n, err := reader.Read(buf) + if err != nil && err != io.EOF { + log.Fatalf("Failed to read from process: %v", err) + } + fmt.Printf("Process output: %s", buf[:n]) + + // Get all processes in the group + processes, err := client.GetGroupProcesses(groupID) + if err != nil { + log.Printf("Failed to get group processes: %v", err) + } else { + fmt.Printf("Processes in group %s: %+v\n", groupID, processes) + } + + // List all groups + groups := client.ListGroups() + fmt.Printf("All groups: %v\n", groups) + + // Stop the process + if err := client.StopProcess(process.ID); err != nil { + log.Fatalf("Failed to stop process: %v", err) + } + + // Close the client + if err := client.Close(); err != nil { + log.Fatalf("Failed to close client: %v", err) + } } diff --git a/pkg/stdio/server.go b/pkg/stdio/server.go index 22450bf..1c07578 100644 --- a/pkg/stdio/server.go +++ b/pkg/stdio/server.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "os" "os/exec" "sync" "time" @@ -16,6 +17,7 @@ import ( // Process represents a running process with its stdio streams type Process struct { ID string + GroupID string Cmd *exec.Cmd Stdin io.WriteCloser Stdout io.ReadCloser @@ -26,6 +28,7 @@ type Process struct { // Server handles process management and stdio streaming type Server struct { processes map[string]*Process + groups map[string][]string // maps group ID to process IDs mu sync.RWMutex upgrader websocket.Upgrader } @@ -34,6 +37,7 @@ type Server struct { func NewServer() *Server { return &Server{ processes: make(map[string]*Process), + groups: make(map[string][]string), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, }, @@ -41,9 +45,13 @@ func NewServer() *Server { } // StartProcess starts a new process and returns its ID -func (s *Server) StartProcess(ctx context.Context, command string, args []string) (string, error) { +func (s *Server) StartProcess(ctx context.Context, command string, args []string, env []string, groupID string) (string, error) { cmd := exec.CommandContext(ctx, command, args...) + if len(env) > 0 { + cmd.Env = append(os.Environ(), env...) + } + stdin, err := cmd.StdinPipe() if err != nil { return "", fmt.Errorf("failed to create stdin pipe: %w", err) @@ -65,6 +73,7 @@ func (s *Server) StartProcess(ctx context.Context, command string, args []string process := &Process{ ID: fmt.Sprintf("%d", time.Now().UnixNano()), + GroupID: groupID, Cmd: cmd, Stdin: stdin, Stdout: stdout, @@ -74,6 +83,9 @@ func (s *Server) StartProcess(ctx context.Context, command string, args []string s.mu.Lock() s.processes[process.ID] = process + if groupID != "" { + s.groups[groupID] = append(s.groups[groupID], process.ID) + } s.mu.Unlock() return process.ID, nil @@ -87,6 +99,21 @@ func (s *Server) StopProcess(id string) error { s.mu.Unlock() return fmt.Errorf("process not found: %s", id) } + + // Remove from group if it exists + if process.GroupID != "" { + groupProcesses := s.groups[process.GroupID] + for i, pid := range groupProcesses { + if pid == id { + s.groups[process.GroupID] = append(groupProcesses[:i], groupProcesses[i+1:]...) + break + } + } + if len(s.groups[process.GroupID]) == 0 { + delete(s.groups, process.GroupID) + } + } + delete(s.processes, id) s.mu.Unlock() @@ -97,6 +124,57 @@ func (s *Server) StopProcess(id string) error { return nil } +// StopGroup stops all processes in a group +func (s *Server) StopGroup(groupID string) error { + s.mu.Lock() + processIDs, exists := s.groups[groupID] + if !exists { + s.mu.Unlock() + return fmt.Errorf("group not found: %s", groupID) + } + s.mu.Unlock() + + for _, pid := range processIDs { + if err := s.StopProcess(pid); err != nil { + return fmt.Errorf("failed to stop process %s in group %s: %w", pid, groupID, err) + } + } + + return nil +} + +// GetGroupProcesses returns all processes in a group +func (s *Server) GetGroupProcesses(groupID string) ([]*Process, error) { + s.mu.RLock() + processIDs, exists := s.groups[groupID] + if !exists { + s.mu.RUnlock() + return nil, fmt.Errorf("group not found: %s", groupID) + } + + processes := make([]*Process, 0, len(processIDs)) + for _, pid := range processIDs { + if process, exists := s.processes[pid]; exists { + processes = append(processes, process) + } + } + s.mu.RUnlock() + + return processes, nil +} + +// ListGroups returns all group IDs +func (s *Server) ListGroups() []string { + s.mu.RLock() + defer s.mu.RUnlock() + + groups := make([]string, 0, len(s.groups)) + for groupID := range s.groups { + groups = append(groups, groupID) + } + return groups +} + // GetProcess returns a process by ID func (s *Server) GetProcess(id string) (*Process, error) { s.mu.RLock() @@ -128,6 +206,8 @@ func (s *Server) Start(addr string) error { http.HandleFunc("/processes", s.handleProcesses) http.HandleFunc("/processes/", s.handleProcess) http.HandleFunc("/ws/", s.handleWebSocket) + http.HandleFunc("/groups", s.handleGroups) + http.HandleFunc("/groups/", s.handleGroup) return http.ListenAndServe(addr, nil) } @@ -141,13 +221,15 @@ func (s *Server) handleProcesses(w http.ResponseWriter, r *http.Request) { var req struct { Command string `json:"command"` Args []string `json:"args"` + Env []string `json:"env"` + GroupID string `json:"group_id"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - id, err := s.StartProcess(r.Context(), req.Command, req.Args) + id, err := s.StartProcess(r.Context(), req.Command, req.Args, req.Env, req.GroupID) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -212,6 +294,7 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { go func() { buf := make([]byte, 1024) for { + n, err := process.Stdout.Read(buf) if err != nil { return @@ -235,3 +318,36 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { // Wait for process to exit process.Cmd.Wait() } + +// Add new handlers for group management +func (s *Server) handleGroups(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + groups := s.ListGroups() + json.NewEncoder(w).Encode(groups) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleGroup(w http.ResponseWriter, r *http.Request) { + groupID := r.URL.Path[len("/groups/"):] + + switch r.Method { + case http.MethodGet: + processes, err := s.GetGroupProcesses(groupID) + if err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + json.NewEncoder(w).Encode(processes) + case http.MethodDelete: + if err := s.StopGroup(groupID); err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + w.WriteHeader(http.StatusNoContent) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +}