wip
Signed-off-by: Ettore Di Giacinto <mudler@localai.io> Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
committed by
mudler
parent
ce997d2425
commit
33b8aaddfe
214
pkg/stdio/client.go
Normal file
214
pkg/stdio/client.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package stdio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// JSONRPCRequest represents a JSON-RPC request
|
||||
type JSONRPCRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
Method string `json:"method"`
|
||||
Params interface{} `json:"params"`
|
||||
}
|
||||
|
||||
// JSONRPCResponse represents a JSON-RPC response
|
||||
type JSONRPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID int64 `json:"id"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *JSONRPCError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// JSONRPCError represents a JSON-RPC error
|
||||
type JSONRPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// JSONRPCNotification represents a JSON-RPC notification
|
||||
type JSONRPCNotification struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Notification struct {
|
||||
Method string `json:"method"`
|
||||
Params interface{} `json:"params,omitempty"`
|
||||
} `json:"notification"`
|
||||
}
|
||||
|
||||
// Client implements the transport.Interface for stdio processes
|
||||
type Client struct {
|
||||
baseURL string
|
||||
processID string
|
||||
conn *websocket.Conn
|
||||
mu sync.Mutex
|
||||
notifyChan chan JSONRPCNotification
|
||||
}
|
||||
|
||||
// NewClient creates a new stdio transport client
|
||||
func NewClient(baseURL string) *Client {
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
notifyChan: make(chan JSONRPCNotification, 100),
|
||||
}
|
||||
}
|
||||
|
||||
// Start initiates the connection to the server
|
||||
func (c *Client) Start(ctx context.Context) error {
|
||||
// Start a new process
|
||||
req := struct {
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
}{
|
||||
Command: "./mcp_server",
|
||||
Args: []string{},
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(
|
||||
fmt.Sprintf("%s/processes", c.baseURL),
|
||||
"application/json",
|
||||
bytes.NewReader(reqBody),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start process: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
c.processID = result.ID
|
||||
|
||||
// Connect to WebSocket
|
||||
u := url.URL{
|
||||
Scheme: "ws",
|
||||
Host: c.baseURL,
|
||||
Path: fmt.Sprintf("/ws/%s", c.processID),
|
||||
}
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
|
||||
// Start notification handler
|
||||
go c.handleNotifications()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close shuts down the client and closes the transport
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
if c.processID != "" {
|
||||
req, err := http.NewRequest(
|
||||
"DELETE",
|
||||
fmt.Sprintf("%s/processes/%s", c.baseURL, c.processID),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stop process: %w", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendRequest sends a JSON-RPC request to the server
|
||||
func (c *Client) SendRequest(
|
||||
ctx context.Context,
|
||||
request JSONRPCRequest,
|
||||
) (*JSONRPCResponse, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
if err := c.conn.WriteJSON(request); err != nil {
|
||||
return nil, fmt.Errorf("failed to write request: %w", err)
|
||||
}
|
||||
|
||||
var response JSONRPCResponse
|
||||
if err := c.conn.ReadJSON(&response); err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// SendNotification sends a JSON-RPC notification to the server
|
||||
func (c *Client) SendNotification(
|
||||
ctx context.Context,
|
||||
notification JSONRPCNotification,
|
||||
) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
return c.conn.WriteJSON(notification)
|
||||
}
|
||||
|
||||
// SetNotificationHandler sets the handler for notifications
|
||||
func (c *Client) SetNotificationHandler(
|
||||
handler func(notification JSONRPCNotification),
|
||||
) {
|
||||
go func() {
|
||||
for notification := range c.notifyChan {
|
||||
handler(notification)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Client) handleNotifications() {
|
||||
for {
|
||||
var notification JSONRPCNotification
|
||||
if err := c.conn.ReadJSON(¬ification); err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case c.notifyChan <- notification:
|
||||
default:
|
||||
// Drop notification if channel is full
|
||||
}
|
||||
}
|
||||
}
|
||||
71
pkg/stdio/example/main.go
Normal file
71
pkg/stdio/example/main.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAGI/pkg/stdio"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Start the server
|
||||
server := stdio.NewServer()
|
||||
go func() {
|
||||
if err := server.Start(":8080"); err != nil {
|
||||
log.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Give the server time to start
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Create a client
|
||||
client := stdio.NewClient("localhost:8080")
|
||||
|
||||
// Start the client
|
||||
if err := client.Start(context.Background()); err != nil {
|
||||
log.Fatalf("Failed to start client: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// Set up notification handler
|
||||
client.SetNotificationHandler(func(notification stdio.JSONRPCNotification) {
|
||||
fmt.Printf("Received notification: %+v\n", notification)
|
||||
})
|
||||
|
||||
// Send a request
|
||||
request := stdio.JSONRPCRequest{
|
||||
JSONRPC: "2.0",
|
||||
ID: 1,
|
||||
Method: "test",
|
||||
Params: map[string]string{"hello": "world"},
|
||||
}
|
||||
|
||||
response, err := client.SendRequest(context.Background(), request)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to send request: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Received response: %+v\n", response)
|
||||
|
||||
// Send a notification
|
||||
notification := stdio.JSONRPCNotification{
|
||||
JSONRPC: "2.0",
|
||||
Notification: struct {
|
||||
Method string `json:"method"`
|
||||
Params interface{} `json:"params,omitempty"`
|
||||
}{
|
||||
Method: "test",
|
||||
Params: map[string]string{"hello": "world"},
|
||||
},
|
||||
}
|
||||
|
||||
if err := client.SendNotification(context.Background(), notification); err != nil {
|
||||
log.Fatalf("Failed to send notification: %v", err)
|
||||
}
|
||||
|
||||
// Keep the program running
|
||||
select {}
|
||||
}
|
||||
237
pkg/stdio/server.go
Normal file
237
pkg/stdio/server.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package stdio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// Process represents a running process with its stdio streams
|
||||
type Process struct {
|
||||
ID string
|
||||
Cmd *exec.Cmd
|
||||
Stdin io.WriteCloser
|
||||
Stdout io.ReadCloser
|
||||
Stderr io.ReadCloser
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// Server handles process management and stdio streaming
|
||||
type Server struct {
|
||||
processes map[string]*Process
|
||||
mu sync.RWMutex
|
||||
upgrader websocket.Upgrader
|
||||
}
|
||||
|
||||
// NewServer creates a new stdio server
|
||||
func NewServer() *Server {
|
||||
return &Server{
|
||||
processes: make(map[string]*Process),
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// StartProcess starts a new process and returns its ID
|
||||
func (s *Server) StartProcess(ctx context.Context, command string, args []string) (string, error) {
|
||||
cmd := exec.CommandContext(ctx, command, args...)
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return "", fmt.Errorf("failed to start process: %w", err)
|
||||
}
|
||||
|
||||
process := &Process{
|
||||
ID: fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
Cmd: cmd,
|
||||
Stdin: stdin,
|
||||
Stdout: stdout,
|
||||
Stderr: stderr,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.processes[process.ID] = process
|
||||
s.mu.Unlock()
|
||||
|
||||
return process.ID, nil
|
||||
}
|
||||
|
||||
// StopProcess stops a running process
|
||||
func (s *Server) StopProcess(id string) error {
|
||||
s.mu.Lock()
|
||||
process, exists := s.processes[id]
|
||||
if !exists {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("process not found: %s", id)
|
||||
}
|
||||
delete(s.processes, id)
|
||||
s.mu.Unlock()
|
||||
|
||||
if err := process.Cmd.Process.Kill(); err != nil {
|
||||
return fmt.Errorf("failed to kill process: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProcess returns a process by ID
|
||||
func (s *Server) GetProcess(id string) (*Process, error) {
|
||||
s.mu.RLock()
|
||||
process, exists := s.processes[id]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("process not found: %s", id)
|
||||
}
|
||||
|
||||
return process, nil
|
||||
}
|
||||
|
||||
// ListProcesses returns all running processes
|
||||
func (s *Server) ListProcesses() []*Process {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
processes := make([]*Process, 0, len(s.processes))
|
||||
for _, p := range s.processes {
|
||||
processes = append(processes, p)
|
||||
}
|
||||
|
||||
return processes
|
||||
}
|
||||
|
||||
// Start starts the HTTP server
|
||||
func (s *Server) Start(addr string) error {
|
||||
http.HandleFunc("/processes", s.handleProcesses)
|
||||
http.HandleFunc("/processes/", s.handleProcess)
|
||||
http.HandleFunc("/ws/", s.handleWebSocket)
|
||||
|
||||
return http.ListenAndServe(addr, nil)
|
||||
}
|
||||
|
||||
func (s *Server) handleProcesses(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
processes := s.ListProcesses()
|
||||
json.NewEncoder(w).Encode(processes)
|
||||
case http.MethodPost:
|
||||
var req struct {
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
id, err := s.StartProcess(r.Context(), req.Command, req.Args)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"id": id})
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleProcess(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.URL.Path[len("/processes/"):]
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
process, err := s.GetProcess(id)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(process)
|
||||
case http.MethodDelete:
|
||||
if err := s.StopProcess(id); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.URL.Path[len("/ws/"):]
|
||||
process, err := s.GetProcess(id)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := s.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Handle stdin
|
||||
go func() {
|
||||
for {
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
process.Stdin.Write(message)
|
||||
}
|
||||
}()
|
||||
|
||||
// Handle stdout
|
||||
go func() {
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
n, err := process.Stdout.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn.WriteMessage(websocket.TextMessage, buf[:n])
|
||||
}
|
||||
}()
|
||||
|
||||
// Handle stderr
|
||||
go func() {
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
n, err := process.Stderr.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn.WriteMessage(websocket.TextMessage, buf[:n])
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for process to exit
|
||||
process.Cmd.Wait()
|
||||
}
|
||||
Reference in New Issue
Block a user