From 9b7344fbdf09db3e4151dfb6cf997160b1f8575d Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 23 Apr 2025 22:53:49 +0200 Subject: [PATCH] fixes Signed-off-by: mudler --- pkg/stdio/client.go | 8 +- pkg/stdio/client_suite_test.go | 13 +++ pkg/stdio/client_test.go | 187 +++++++++++++++++++++++++++++++++ pkg/stdio/server.go | 128 ++++++++++++---------- 4 files changed, 275 insertions(+), 61 deletions(-) create mode 100644 pkg/stdio/client_suite_test.go create mode 100644 pkg/stdio/client_test.go diff --git a/pkg/stdio/client.go b/pkg/stdio/client.go index 01ff256..1de5fdd 100644 --- a/pkg/stdio/client.go +++ b/pkg/stdio/client.go @@ -58,7 +58,8 @@ func (c *Client) CreateProcess(ctx context.Context, command string, args []strin resp, err := http.Post(url, "application/json", bytes.NewReader(reqBody)) if err != nil { - return nil, fmt.Errorf("failed to start process: %w", err) + // resp is nil when http.Post returns an error, so there is no body to read here. + return nil, fmt.Errorf("failed to start process: %w", err) } defer resp.Body.Close() @@ -68,8 +69,9 @@ func (c *Client) CreateProcess(ctx context.Context, command string, args []strin ID string `json:"id"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - body, _ := io.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) + + if err := json.Unmarshal(body, &result); err != nil { return nil, fmt.Errorf("failed to decode response: %w. body: %s", err, string(body)) } diff --git a/pkg/stdio/client_suite_test.go b/pkg/stdio/client_suite_test.go new file mode 100644 index 0000000..11e52c0 --- /dev/null +++ b/pkg/stdio/client_suite_test.go @@ -0,0 +1,13 @@ +package stdio + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +func TestSTDIOTransport(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "STDIOTransport test suite") +} diff --git a/pkg/stdio/client_test.go b/pkg/stdio/client_test.go new file mode 100644 index 0000000..9d6e722 --- /dev/null +++ b/pkg/stdio/client_test.go @@ -0,0 +1,187 @@ +package stdio + +import ( + "context" + "os" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Client", func() { + var ( + client *Client + baseURL string + ) + + BeforeEach(func() { + baseURL = os.Getenv("STDIO_SERVER_URL") + if baseURL == "" { + baseURL = "http://localhost:8080" + } + client = NewClient(baseURL) + }) + + AfterEach(func() { + if client != nil { + Expect(client.Close()).To(Succeed()) + } + }) + + Context("Process Management", func() { + It("should create and stop a process", func() { + ctx := context.Background() + // Use a command that doesn't exit immediately + process, err := client.CreateProcess(ctx, "sh", []string{"-c", "echo 'Hello, World!'; sleep 10"}, []string{}, "test-group") + Expect(err).NotTo(HaveOccurred()) + Expect(process).NotTo(BeNil()) + Expect(process.ID).NotTo(BeEmpty()) + + // Get process IO + reader, writer, err := client.GetProcessIO(process.ID) + Expect(err).NotTo(HaveOccurred()) + Expect(reader).NotTo(BeNil()) + Expect(writer).NotTo(BeNil()) + + // Write to process + _, err = writer.Write([]byte("test input\n")) + Expect(err).NotTo(HaveOccurred()) + + // Read from process with timeout + buf := make([]byte, 1024) + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + readN, readErr = reader.Read(buf) + close(readDone) + }() + + // Wait for read with timeout + select { + case <-readDone: + Expect(readErr).NotTo(HaveOccurred()) + Expect(readN).To(BeNumerically(">", 0)) + Expect(string(buf[:readN])).To(ContainSubstring("Hello, World!")) + case <-time.After(5 * time.Second): + Fail("Timeout waiting for process output") + } + + // Stop the 
process + err = client.StopProcess(process.ID) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should manage process groups", func() { + ctx := context.Background() + groupID := "test-group" + + // Create multiple processes in the same group + process1, err := client.CreateProcess(ctx, "sh", []string{"-c", "echo 'Process 1'; sleep 1"}, []string{}, groupID) + Expect(err).NotTo(HaveOccurred()) + Expect(process1).NotTo(BeNil()) + + process2, err := client.CreateProcess(ctx, "sh", []string{"-c", "echo 'Process 2'; sleep 1"}, []string{}, groupID) + Expect(err).NotTo(HaveOccurred()) + Expect(process2).NotTo(BeNil()) + + // Get group processes + processes, err := client.GetGroupProcesses(groupID) + Expect(err).NotTo(HaveOccurred()) + Expect(processes).To(HaveLen(2)) + + // List groups + groups := client.ListGroups() + Expect(groups).To(ContainElement(groupID)) + + // Stop the group + err = client.StopGroup(groupID) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should run a one-time process", func() { + ctx := context.Background() + output, err := client.RunProcess(ctx, "echo", []string{"One-time process"}, []string{}) + Expect(err).NotTo(HaveOccurred()) + Expect(output).To(ContainSubstring("One-time process")) + }) + + It("should handle process with environment variables", func() { + ctx := context.Background() + env := []string{"TEST_VAR=test_value"} + process, err := client.CreateProcess(ctx, "sh", []string{"-c", "env | grep TEST_VAR; sleep 1"}, env, "test-group") + Expect(err).NotTo(HaveOccurred()) + Expect(process).NotTo(BeNil()) + + // Get process IO + reader, _, err := client.GetProcessIO(process.ID) + Expect(err).NotTo(HaveOccurred()) + + // Read environment variables with timeout + buf := make([]byte, 1024) + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + readN, readErr = reader.Read(buf) + close(readDone) + }() + + // Wait for read with timeout + select { + case <-readDone: + Expect(readErr).NotTo(HaveOccurred()) + 
Expect(readN).To(BeNumerically(">", 0)) + Expect(string(buf[:readN])).To(ContainSubstring("TEST_VAR=test_value")) + case <-time.After(5 * time.Second): + Fail("Timeout waiting for process output") + } + + // Stop the process + err = client.StopProcess(process.ID) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should handle long-running processes", func() { + ctx := context.Background() + process, err := client.CreateProcess(ctx, "sh", []string{"-c", "echo 'Starting long process'; sleep 5"}, []string{}, "test-group") + Expect(err).NotTo(HaveOccurred()) + Expect(process).NotTo(BeNil()) + + // Get process IO + reader, _, err := client.GetProcessIO(process.ID) + Expect(err).NotTo(HaveOccurred()) + + // Read initial output + buf := make([]byte, 1024) + readDone := make(chan struct{}) + var readErr error + var readN int + + go func() { + readN, readErr = reader.Read(buf) + close(readDone) + }() + + // Wait for read with timeout + select { + case <-readDone: + Expect(readErr).NotTo(HaveOccurred()) + Expect(readN).To(BeNumerically(">", 0)) + Expect(string(buf[:readN])).To(ContainSubstring("Starting long process")) + case <-time.After(5 * time.Second): + Fail("Timeout waiting for process output") + } + + // Wait a bit to ensure process is running + time.Sleep(time.Second) + + // Stop the process + err = client.StopProcess(process.ID) + Expect(err).NotTo(HaveOccurred()) + }) + }) +}) diff --git a/pkg/stdio/server.go b/pkg/stdio/server.go index 7b3bf54..ec9c6cf 100644 --- a/pkg/stdio/server.go +++ b/pkg/stdio/server.go @@ -1,7 +1,7 @@ package stdio import ( - "bytes" + "bufio" "context" "encoding/json" "fmt" @@ -14,6 +14,7 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/mudler/LocalAGI/pkg/xlog" ) // Process represents a running process with its stdio streams @@ -48,12 +49,13 @@ func NewServer() *Server { // StartProcess starts a new process and returns its ID func (s *Server) StartProcess(ctx context.Context, command string, args []string, env []string, 
groupID string) (string, error) { - log.Printf("Starting process: command=%s, args=%v, groupID=%s", command, args, groupID) + xlog.Debug("Starting process", "command", command, "args", args, "groupID", groupID) cmd := exec.CommandContext(ctx, command, args...) if len(env) > 0 { cmd.Env = append(os.Environ(), env...) + xlog.Debug("Process environment", "env", cmd.Env) } stdin, err := cmd.StdinPipe() @@ -92,7 +94,7 @@ func (s *Server) StartProcess(ctx context.Context, command string, args []string } s.mu.Unlock() - log.Printf("Successfully started process with ID: %s", process.ID) + xlog.Debug("Successfully started process", "id", process.ID, "pid", cmd.Process.Pid) return process.ID, nil } @@ -105,6 +107,8 @@ func (s *Server) StopProcess(id string) error { return fmt.Errorf("process not found: %s", id) } + xlog.Debug("Stopping process", "processID", id, "pid", process.Cmd.Process.Pid) + // Remove from group if it exists if process.GroupID != "" { groupProcesses := s.groups[process.GroupID] @@ -123,9 +127,11 @@ func (s *Server) StopProcess(id string) error { s.mu.Unlock() if err := process.Cmd.Process.Kill(); err != nil { + xlog.Debug("Failed to kill process", "processID", id, "pid", process.Cmd.Process.Pid, "error", err) return fmt.Errorf("failed to kill process: %w", err) } + xlog.Debug("Successfully killed process", "processID", id, "pid", process.Cmd.Process.Pid) return nil } @@ -253,7 +259,7 @@ func (s *Server) handleProcesses(w http.ResponseWriter, r *http.Request) { return } - id, err := s.StartProcess(r.Context(), req.Command, req.Args, req.Env, req.GroupID) + id, err := s.StartProcess(context.Background(), req.Command, req.Args, req.Env, req.GroupID) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -290,7 +296,7 @@ func (s *Server) handleProcess(w http.ResponseWriter, r *http.Request) { func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { id := r.URL.Path[len("/ws/"):] - log.Printf("Handling WebSocket 
connection for process: %s", id) + xlog.Debug("Handling WebSocket connection", "processID", id) process, err := s.GetProcess(id) if err != nil { @@ -298,6 +304,14 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { return } + if process.Cmd.ProcessState != nil && process.Cmd.ProcessState.Exited() { + xlog.Debug("Process already exited", "processID", id) + http.Error(w, "Process already exited", http.StatusGone) + return + } + + xlog.Debug("Process is running", "processID", id, "pid", process.Cmd.Process.Pid) + conn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -305,24 +319,19 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { } defer conn.Close() - log.Printf("WebSocket connection established for process: %s", id) + xlog.Debug("WebSocket connection established", "processID", id) // Create a done channel to signal process completion done := make(chan struct{}) - // Create buffers to capture output - var stdoutBuf, stderrBuf bytes.Buffer - stdoutTee := io.TeeReader(process.Stdout, &stdoutBuf) - stderrTee := io.TeeReader(process.Stderr, &stderrBuf) - // Handle stdin go func() { defer func() { select { case <-done: - // Process already done, this is expected + xlog.Debug("Process stdin handler done", "processID", id) default: - log.Printf("WebSocket stdin connection closed for process %s", id) + xlog.Debug("WebSocket stdin connection closed", "processID", id) } }() @@ -330,86 +339,89 @@ func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { _, message, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - log.Printf("WebSocket stdin unexpected error for process %s: %v", id, err) + xlog.Debug("WebSocket stdin unexpected error", "processID", id, "error", err) } return } + xlog.Debug("Received message", "processID", id, "message", 
string(message)) if _, err := process.Stdin.Write(message); err != nil { if err != io.EOF { - log.Printf("WebSocket stdin write error for process %s: %v", id, err) + xlog.Debug("WebSocket stdin write error", "processID", id, "error", err) } return } + xlog.Debug("Message sent to process", "processID", id, "message", string(message)) } }() - // Handle stdout + // Handle stdout and stderr using bufio.Scanner go func() { defer func() { select { case <-done: - // Process already done, this is expected + xlog.Debug("Process output handler done", "processID", id) default: - log.Printf("WebSocket stdout connection closed for process %s", id) + xlog.Debug("WebSocket output connection closed", "processID", id) } }() - buf := make([]byte, 1024) - for { - n, err := stdoutTee.Read(buf) - if err != nil { - if err != io.EOF { - log.Printf("WebSocket stdout read error for process %s: %v", id, err) - } - return + // Create a scanner that reads from both stdout and stderr + scanner := bufio.NewScanner(io.MultiReader(process.Stdout, process.Stderr)) + // Set a larger buffer size for JSON-RPC messages (10MB) + scanner.Buffer(make([]byte, 10*1024*1024), 10*1024*1024) + // Use a custom split function to handle JSON-RPC messages + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil } - if err := conn.WriteMessage(websocket.TextMessage, buf[:n]); err != nil { + + // Look for the end of a JSON-RPC message + for i := 0; i < len(data); i++ { + if data[i] == '\n' { + return i + 1, data[:i], nil + } + } + + // If we're at EOF, return the remaining data + if atEOF { + return len(data), data, nil + } + + // Request more data + return 0, nil, nil + }) + + for scanner.Scan() { + line := scanner.Text() + xlog.Debug("Sending message", "processID", id, "message", line) + if err := conn.WriteMessage(websocket.TextMessage, []byte(line)); err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, 
websocket.CloseNormalClosure) { - log.Printf("WebSocket stdout write error for process %s: %v", id, err) + xlog.Debug("WebSocket output write error", "processID", id, "error", err) } return } + xlog.Debug("Message sent to client", "processID", id, "message", line) } - }() - // Handle stderr - go func() { - defer func() { - select { - case <-done: - // Process already done, this is expected - default: - log.Printf("WebSocket stderr connection closed for process %s", id) - } - }() - - buf := make([]byte, 1024) - for { - n, err := stderrTee.Read(buf) - if err != nil { - if err != io.EOF { - log.Printf("WebSocket stderr read error for process %s: %v", id, err) - } - return - } - if err := conn.WriteMessage(websocket.TextMessage, buf[:n]); err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { - log.Printf("WebSocket stderr write error for process %s: %v", id, err) - } - return - } + if err := scanner.Err(); err != nil { + xlog.Debug("Scanner error", "processID", id, "error", err) } }() // Wait for process to exit + xlog.Debug("Waiting for process to exit", "processID", id) err = process.Cmd.Wait() close(done) // Signal that the process is done if err != nil { - log.Printf("Process %s exited with error: %v\nstdout: %s\nstderr: %s", - id, err, stdoutBuf.String(), stderrBuf.String()) + xlog.Debug("Process exited with error", + "processID", id, + "pid", process.Cmd.Process.Pid, + "error", err) } else { - log.Printf("Process %s exited successfully", id) + xlog.Debug("Process exited successfully", + "processID", id, + "pid", process.Cmd.Process.Pid) } }