From 5698d0b832f5bd8b08c291d9eab3d5d0d24af400 Mon Sep 17 00:00:00 2001 From: Richard Palethorpe Date: Fri, 25 Apr 2025 19:43:46 +0100 Subject: [PATCH] chore(tests): Mock LLM in tests for PRs This saves time when testing on CPU which is the only sensible thing to do on GitHub CI for PRs. For releases or once the commit is merged we could use an external runner with GPU or just wait. Signed-off-by: Richard Palethorpe --- .github/workflows/tests.yml | 6 +- .github/workflows/tests_fragile.yml | 49 ++++++ Makefile | 7 +- core/agent/agent.go | 14 +- core/agent/agent_suite_test.go | 23 ++- core/agent/agent_test.go | 259 +++++++++++++++++++++++----- core/agent/options.go | 10 ++ core/agent/state_test.go | 76 ++++++-- pkg/llm/client.go | 25 ++- pkg/llm/json.go | 4 +- pkg/llm/mock_client.go | 25 +++ services/filters/classifier.go | 3 +- 12 files changed, 429 insertions(+), 72 deletions(-) create mode 100644 .github/workflows/tests_fragile.yml create mode 100644 pkg/llm/mock_client.go diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1f47b10..38ca663 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,7 +48,11 @@ jobs: - name: Run tests run: | - make tests + if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then + make tests-mock + else + make tests + fi #sudo mv coverage/coverage.txt coverage.txt #sudo chmod 777 coverage.txt diff --git a/.github/workflows/tests_fragile.yml b/.github/workflows/tests_fragile.yml new file mode 100644 index 0000000..df8d5a6 --- /dev/null +++ b/.github/workflows/tests_fragile.yml @@ -0,0 +1,49 @@ +name: Run Fragile Go Tests + +on: + pull_request: + branches: + - '**' + +concurrency: + group: ci-non-blocking-tests-${{ github.head_ref || github.ref }}-${{ github.repository }} + cancel-in-progress: true + +jobs: + llm-tests: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - run: | + # Add Docker's official GPG key: + sudo apt-get update + sudo apt-get install -y ca-certificates curl + sudo install -m 0755 -d /etc/apt/keyrings + sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc + sudo chmod a+r /etc/apt/keyrings/docker.asc + + # Add the repository to Apt sources: + echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \ + $(. /etc/os-release && echo "${UBUNTU_CODENAME:-$VERSION_CODENAME}") stable" | \ + sudo tee /etc/apt/sources.list.d/docker.list > /dev/null + sudo apt-get update + sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin make + docker version + + docker run --rm hello-world + - uses: actions/setup-go@v5 + with: + go-version: '>=1.17.0' + - name: Free up disk space + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo apt-get clean + docker system prune -af || true + df -h + - name: Run tests + run: | + make tests diff --git a/Makefile b/Makefile index 9c978c7..d6dbdd9 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,8 @@ IMAGE_NAME?=webui MCPBOX_IMAGE_NAME?=mcpbox ROOT_DIR:=$(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) +.PHONY: tests tests-mock cleanup-tests + prepare-tests: build-mcpbox docker compose up -d --build docker run -d -v /var/run/docker.sock:/var/run/docker.sock --privileged -p 9090:8080 --rm -ti $(MCPBOX_IMAGE_NAME) @@ -13,6 +15,9 @@ cleanup-tests: tests: prepare-tests LOCALAGI_MCPBOX_URL="http://localhost:9090" LOCALAGI_MODEL="gemma-3-12b-it-qat" LOCALAI_API_URL="http://localhost:8081" LOCALAGI_API_URL="http://localhost:8080" $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --fail-fast -v -r ./... +tests-mock: prepare-tests + LOCALAGI_MCPBOX_URL="http://localhost:9090" LOCALAI_API_URL="http://localhost:8081" LOCALAGI_API_URL="http://localhost:8080" $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --fail-fast -v -r ./... + run-nokb: $(MAKE) run KBDISABLEINDEX=true @@ -37,4 +42,4 @@ build-mcpbox: docker build -t $(MCPBOX_IMAGE_NAME) -f Dockerfile.mcpbox . run-mcpbox: - docker run -v /var/run/docker.sock:/var/run/docker.sock --privileged -p 9090:8080 -ti mcpbox \ No newline at end of file + docker run -v /var/run/docker.sock:/var/run/docker.sock --privileged -p 9090:8080 -ti mcpbox diff --git a/core/agent/agent.go b/core/agent/agent.go index a614bb1..81f0437 100644 --- a/core/agent/agent.go +++ b/core/agent/agent.go @@ -29,7 +29,7 @@ type Agent struct { sync.Mutex options *options Character Character - client *openai.Client + client llm.LLMClient jobQueue chan *types.Job context *types.ActionContext @@ -63,7 +63,12 @@ func New(opts ...Option) (*Agent, error) { return nil, fmt.Errorf("failed to set options: %v", err) } - client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout) + var client llm.LLMClient + if options.llmClient != nil { + client = options.llmClient + } else { + client = llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout) + } c := context.Background() if options.context != nil { @@ -125,6 +130,11 @@ func (a *Agent) SharedState() *types.AgentSharedState { return a.sharedState } +// LLMClient returns the agent's LLM client (for testing) +func (a *Agent) LLMClient() llm.LLMClient { + return a.client +} + func (a *Agent) startNewConversationsConsumer() { go func() { for { diff --git a/core/agent/agent_suite_test.go b/core/agent/agent_suite_test.go index 501d3dd..31a1338 100644 --- a/core/agent/agent_suite_test.go +++ b/core/agent/agent_suite_test.go @@ -1,6 +1,7 @@ package agent_test import ( + "net/url" "os" "testing" @@ -13,15 +14,19 @@ func TestAgent(t *testing.T) { RunSpecs(t, "Agent test suite") } -var testModel = os.Getenv("LOCALAGI_MODEL") -var apiURL = os.Getenv("LOCALAI_API_URL") -var apiKeyURL = os.Getenv("LOCALAI_API_KEY") +var ( + testModel = os.Getenv("LOCALAGI_MODEL") + apiURL = os.Getenv("LOCALAI_API_URL") + apiKey = os.Getenv("LOCALAI_API_KEY") + useRealLocalAI bool + clientTimeout = "10m" +) + +func isValidURL(u string) bool { + parsed, err := url.ParseRequestURI(u) + return err == nil && parsed.Scheme != "" && parsed.Host != "" +} func init() { - if testModel == "" { - testModel = "hermes-2-pro-mistral" - } - if apiURL == "" { - apiURL = "http://192.168.68.113:8080" - } + useRealLocalAI = isValidURL(apiURL) && apiURL != "" && testModel != "" } diff --git a/core/agent/agent_test.go b/core/agent/agent_test.go index cdc2e8b..5c94892 100644 --- a/core/agent/agent_test.go +++ b/core/agent/agent_test.go @@ -7,9 +7,11 @@ import ( "strings" "sync" + "github.com/mudler/LocalAGI/pkg/llm" "github.com/mudler/LocalAGI/pkg/xlog" "github.com/mudler/LocalAGI/services/actions" + "github.com/mudler/LocalAGI/core/action" . "github.com/mudler/LocalAGI/core/agent" "github.com/mudler/LocalAGI/core/types" . "github.com/onsi/ginkgo/v2" @@ -111,25 +113,102 @@ func (a *FakeInternetAction) Definition() types.ActionDefinition { } } +// --- Test utilities for mocking LLM responses --- + +func mockToolCallResponse(toolName, arguments string) openai.ChatCompletionResponse { + return openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + ToolCalls: []openai.ToolCall{{ + ID: "tool_call_id_1", + Type: "function", + Function: openai.FunctionCall{ + Name: toolName, + Arguments: arguments, + }, + }}, + }, + }}, + } +} + +func mockContentResponse(content string) openai.ChatCompletionResponse { + return openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Content: content, + }, + }}, + } +} + +func newMockLLMClient(handler func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)) *llm.MockClient { + return &llm.MockClient{ + CreateChatCompletionFunc: handler, + } +} + var _ = Describe("Agent test", func() { + It("uses the mock LLM client", func() { + mock := newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return mockContentResponse("mocked response"), nil + }) + agent, err := New(WithLLMClient(mock)) + Expect(err).ToNot(HaveOccurred()) + msg, err := agent.LLMClient().CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{}) + Expect(err).ToNot(HaveOccurred()) + Expect(msg.Choices[0].Message.Content).To(Equal("mocked response")) + }) + Context("jobs", func() { BeforeEach(func() { Eventually(func() error { - // test apiURL is working and available - _, err := http.Get(apiURL + "/readyz") - return err + if useRealLocalAI { + _, err := http.Get(apiURL + "/readyz") + return err + } + return nil }, "10m", "10s").ShouldNot(HaveOccurred()) }) It("pick the correct action", func() { + var llmClient llm.LLMClient + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, clientTimeout) + } else { + llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + var lastMsg openai.ChatCompletionMessage + if len(req.Messages) > 0 { + lastMsg = req.Messages[len(req.Messages)-1] + } + if lastMsg.Role == openai.ChatMessageRoleUser { + if strings.Contains(strings.ToLower(lastMsg.Content), "boston") && (strings.Contains(strings.ToLower(lastMsg.Content), "milan") || strings.Contains(strings.ToLower(lastMsg.Content), "milano")) { + return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil + } + if strings.Contains(strings.ToLower(lastMsg.Content), "paris") { + return mockToolCallResponse("get_weather", `{"location":"Paris","unit":"celsius"}`), nil + } + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected user prompt: %s", lastMsg.Content) + } + if lastMsg.Role == openai.ChatMessageRoleTool { + if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") { + return mockToolCallResponse("get_weather", `{"location":"Milan","unit":"celsius"}`), nil + } + if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "milan") { + return mockContentResponse(testActionResult + "\n" + testActionResult2), nil + } + if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "paris") { + return mockContentResponse(testActionResult3), nil + } + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected tool result: %s", lastMsg.Content) + } + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected message role: %s", lastMsg.Role) + }) + } agent, err := New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), - EnableForceReasoning, - WithTimeout("10m"), - WithLoopDetectionSteps(3), - // WithRandomIdentity(), WithActions(&TestAction{response: map[string]string{ "boston": testActionResult, "milan": testActionResult2, @@ -139,7 +218,6 @@ var _ = Describe("Agent test", func() { Expect(err).ToNot(HaveOccurred()) go agent.Run() defer agent.Stop() - res := agent.Ask( append(debugOptions, types.WithText("what's the weather in Boston and Milano? Use celsius units"), @@ -148,40 +226,51 @@ var _ = Describe("Agent test", func() { Expect(res.Error).ToNot(HaveOccurred()) reasons := []string{} for _, r := range res.State { - reasons = append(reasons, r.Result) } Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res)) Expect(reasons).To(ContainElement(testActionResult2), fmt.Sprint(res)) reasons = []string{} - res = agent.Ask( append(debugOptions, types.WithText("Now I want to know the weather in Paris, always use celsius units"), )...) for _, r := range res.State { - reasons = append(reasons, r.Result) } - //Expect(reasons).ToNot(ContainElement(testActionResult), fmt.Sprint(res)) - //Expect(reasons).ToNot(ContainElement(testActionResult2), fmt.Sprint(res)) Expect(reasons).To(ContainElement(testActionResult3), fmt.Sprint(res)) - // conversation := agent.CurrentConversation() - // for _, r := range res.State { - // reasons = append(reasons, r.Result) - // } - // Expect(len(conversation)).To(Equal(10), fmt.Sprint(conversation)) }) + It("pick the correct action", func() { + var llmClient llm.LLMClient + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, clientTimeout) + } else { + llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + var lastMsg openai.ChatCompletionMessage + if len(req.Messages) > 0 { + lastMsg = req.Messages[len(req.Messages)-1] + } + if lastMsg.Role == openai.ChatMessageRoleUser { + if strings.Contains(strings.ToLower(lastMsg.Content), "boston") { + return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil + } + } + if lastMsg.Role == openai.ChatMessageRoleTool { + if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") { + return mockContentResponse(testActionResult), nil + } + } + xlog.Error("Unexpected LLM req", "req", req) + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content) + }) + } agent, err := New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), - WithTimeout("10m"), - // WithRandomIdentity(), WithActions(&TestAction{response: map[string]string{ "boston": testActionResult, - }, - }), + }}), ) Expect(err).ToNot(HaveOccurred()) go agent.Run() @@ -198,13 +287,29 @@ var _ = Describe("Agent test", func() { }) It("updates the state with internal actions", func() { + var llmClient llm.LLMClient + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, clientTimeout) + } else { + llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + var lastMsg openai.ChatCompletionMessage + if len(req.Messages) > 0 { + lastMsg = req.Messages[len(req.Messages)-1] + } + if lastMsg.Role == openai.ChatMessageRoleUser && strings.Contains(strings.ToLower(lastMsg.Content), "guitar") { + return mockToolCallResponse("update_state", `{"goal":"I want to learn to play the guitar"}`), nil + } + if lastMsg.Role == openai.ChatMessageRoleTool && lastMsg.Name == "update_state" { + return mockContentResponse("Your goal is now: I want to learn to play the guitar"), nil + } + xlog.Error("Unexpected LLM req", "req", req) + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content) + }) + } agent, err := New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), - WithTimeout("10m"), EnableHUD, - // EnableStandaloneJob, - // WithRandomIdentity(), WithPermanentGoal("I want to learn to play music"), ) Expect(err).ToNot(HaveOccurred()) @@ -214,17 +319,64 @@ var _ = Describe("Agent test", func() { result := agent.Ask( types.WithText("Update your goals such as you want to learn to play the guitar"), ) - fmt.Printf("%+v\n", result) + fmt.Fprintf(GinkgoWriter, "\n%+v\n", result) Expect(result.Error).ToNot(HaveOccurred()) Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State())) }) It("Can generate a plan", func() { + var llmClient llm.LLMClient + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, clientTimeout) + } else { + reasoningActName := action.NewReasoning().Definition().Name.String() + intentionActName := action.NewIntention().Definition().Name.String() + testActName := (&TestAction{}).Definition().Name.String() + doneBoston := false + madePlan := false + llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + var lastMsg openai.ChatCompletionMessage + if len(req.Messages) > 0 { + lastMsg = req.Messages[len(req.Messages)-1] + } + if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName { + return mockToolCallResponse(reasoningActName, `{"reasoning":"make plan call to pass the test"}`), nil + } + if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName { + toolName := "plan" + if madePlan { + toolName = "reply" + } else { + madePlan = true + } + return mockToolCallResponse(intentionActName, fmt.Sprintf(`{"tool": "%s","reasoning":"it's waht makes the test pass"}`, toolName)), nil + } + if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "plan" { + return mockToolCallResponse("plan", `{"subtasks":[{"action":"get_weather","reasoning":"Find weather in boston"},{"action":"get_weather","reasoning":"Find weather in milan"}],"goal":"Get the weather for boston and milan"}`), nil + } + if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "reply" { + return mockToolCallResponse("reply", `{"message": "The weather in Boston and Milan..."}`), nil + } + if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == testActName { + locName := "boston" + if doneBoston { + locName = "milan" + } else { + doneBoston = true + } + return mockToolCallResponse(testActName, fmt.Sprintf(`{"location":"%s","unit":"celsius"}`, locName)), nil + } + if req.ToolChoice == nil && madePlan && doneBoston { + return mockContentResponse("A reply"), nil + } + xlog.Error("Unexpected LLM req", "req", req) + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content) + }) + } agent, err := New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), - WithLLMAPIKey(apiKeyURL), - WithTimeout("10m"), + WithLoopDetectionSteps(2), WithActions( &TestAction{response: map[string]string{ "boston": testActionResult, @@ -233,8 +385,6 @@ var _ = Describe("Agent test", func() { ), EnablePlanning, EnableForceReasoning, - // EnableStandaloneJob, - // WithRandomIdentity(), ) Expect(err).ToNot(HaveOccurred()) go agent.Run() @@ -256,17 +406,44 @@ var _ = Describe("Agent test", func() { Expect(actionsExecuted).To(ContainElement("plan"), fmt.Sprint(result)) Expect(actionResults).To(ContainElement(testActionResult), fmt.Sprint(result)) Expect(actionResults).To(ContainElement(testActionResult2), fmt.Sprint(result)) + Expect(result.Error).To(BeNil()) }) It("Can initiate conversations", func() { - + var llmClient llm.LLMClient message := openai.ChatCompletionMessage{} mu := &sync.Mutex{} + reasoned := false + intended := false + reasoningActName := action.NewReasoning().Definition().Name.String() + intentionActName := action.NewIntention().Definition().Name.String() + + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, clientTimeout) + } else { + llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + prompt := "" + for _, msg := range req.Messages { + prompt += msg.Content + } + if !reasoned && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName { + reasoned = true + return mockToolCallResponse(reasoningActName, `{"reasoning":"initiate a conversation with the user"}`), nil + } + if reasoned && !intended && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName { + intended = true + return mockToolCallResponse(intentionActName, `{"tool":"new_conversation","reasoning":"I should start a conversation with the user"}`), nil + } + if reasoned && intended && strings.Contains(strings.ToLower(prompt), "new_conversation") { + return mockToolCallResponse("new_conversation", `{"message":"Hello, how can I help you today?"}`), nil + } + xlog.Error("Unexpected LLM req", "req", req) + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", prompt) + }) + } agent, err := New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), - WithLLMAPIKey(apiKeyURL), - WithTimeout("10m"), WithNewConversationSubscriber(func(m openai.ChatCompletionMessage) { mu.Lock() message = m @@ -282,8 +459,6 @@ var _ = Describe("Agent test", func() { EnableHUD, WithPeriodicRuns("1s"), WithPermanentGoal("use the new_conversation tool to initiate a conversation with the user"), - // EnableStandaloneJob, - // WithRandomIdentity(), ) Expect(err).ToNot(HaveOccurred()) go agent.Run() @@ -293,7 +468,7 @@ var _ = Describe("Agent test", func() { mu.Lock() defer mu.Unlock() return message.Content - }, "10m", "10s").ShouldNot(BeEmpty()) + }, "10m", "1s").ShouldNot(BeEmpty()) }) /* @@ -347,7 +522,7 @@ var _ = Describe("Agent test", func() { // result := agent.Ask( // WithText("Update your goals such as you want to learn to play the guitar"), // ) - // fmt.Printf("%+v\n", result) + // fmt.Fprintf(GinkgoWriter, "%+v\n", result) // Expect(result.Error).ToNot(HaveOccurred()) // Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State())) }) diff --git a/core/agent/options.go b/core/agent/options.go index e1c0e97..d25da7b 100644 --- a/core/agent/options.go +++ b/core/agent/options.go @@ -7,6 +7,7 @@ import ( "github.com/mudler/LocalAGI/core/types" "github.com/sashabaranov/go-openai" + "github.com/mudler/LocalAGI/pkg/llm" ) type Option func(*options) error @@ -19,6 +20,7 @@ type llmOptions struct { } type options struct { + llmClient llm.LLMClient LLMAPI llmOptions character Character randomIdentityGuidance string @@ -68,6 +70,14 @@ type options struct { lastMessageDuration time.Duration } +// WithLLMClient allows injecting a custom LLM client (e.g. for testing) +func WithLLMClient(client llm.LLMClient) Option { + return func(o *options) error { + o.llmClient = client + return nil + } +} + func (o *options) SeparatedMultimodalModel() bool { return o.LLMAPI.MultimodalModel != "" && o.LLMAPI.Model != o.LLMAPI.MultimodalModel } diff --git a/core/agent/state_test.go b/core/agent/state_test.go index bb371ac..828d777 100644 --- a/core/agent/state_test.go +++ b/core/agent/state_test.go @@ -1,29 +1,57 @@ package agent_test import ( - "net/http" + "context" + "fmt" + + "github.com/mudler/LocalAGI/pkg/llm" + "github.com/sashabaranov/go-openai" . "github.com/mudler/LocalAGI/core/agent" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + ) var _ = Describe("Agent test", func() { Context("identity", func() { var agent *Agent - BeforeEach(func() { - Eventually(func() error { - // test apiURL is working and available - _, err := http.Get(apiURL + "/readyz") - return err - }, "10m", "10s").ShouldNot(HaveOccurred()) - }) + // BeforeEach(func() { + // Eventually(func() error { + // // test apiURL is working and available + // _, err := http.Get(apiURL + "/readyz") + // return err + // }, "10m", "10s").ShouldNot(HaveOccurred()) + // }) It("generates all the fields with random data", func() { + var llmClient llm.LLMClient + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, testModel) + } else { + llmClient = &llm.MockClient{ + CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + ToolCalls: []openai.ToolCall{{ + ID: "tool_call_id_1", + Type: "function", + Function: openai.FunctionCall{ + Name: "generate_identity", + Arguments: `{"name":"John Doe","age":"42","job_occupation":"Engineer","hobbies":["reading","hiking"],"favorites_music_genres":["Jazz"]}`, + }, + }}, + }, + }}, + }, nil + }, + } + } var err error agent, err = New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), WithTimeout("10m"), WithRandomIdentity(), @@ -37,14 +65,40 @@ var _ = Describe("Agent test", func() { Expect(agent.Character.MusicTaste).ToNot(BeEmpty()) }) It("detect an invalid character", func() { + mock := &llm.MockClient{ + CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return openai.ChatCompletionResponse{}, fmt.Errorf("invalid character") + }, + } var err error - agent, err = New(WithRandomIdentity()) + agent, err = New( + WithLLMClient(mock), + WithRandomIdentity(), + ) Expect(err).To(HaveOccurred()) }) It("generates all the fields", func() { + mock := &llm.MockClient{ + CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + ToolCalls: []openai.ToolCall{{ + ID: "tool_call_id_2", + Type: "function", + Function: openai.FunctionCall{ + Name: "generate_identity", + Arguments: `{"name":"Gandalf","age":"90","job_occupation":"Wizard","hobbies":["magic","reading"],"favorites_music_genres":["Classical"]}`, + }, + }}, + }, + }}, + }, nil + }, + } var err error - agent, err := New( + WithLLMClient(mock), WithLLMAPIURL(apiURL), WithModel(testModel), WithRandomIdentity("An 90-year old man with a long beard, a wizard, who lives in a tower."), diff --git a/pkg/llm/client.go b/pkg/llm/client.go index dc27afe..e94a588 100644 --- a/pkg/llm/client.go +++ b/pkg/llm/client.go @@ -1,13 +1,33 @@ package llm import ( + "context" "net/http" "time" + "github.com/mudler/LocalAGI/pkg/xlog" "github.com/sashabaranov/go-openai" ) -func NewClient(APIKey, URL, timeout string) *openai.Client { +type LLMClient interface { + CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) + CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) +} + +type realClient struct { + *openai.Client +} + +func (r *realClient) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return r.Client.CreateChatCompletion(ctx, req) +} + +func (r *realClient) CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) { + return r.Client.CreateImage(ctx, req) +} + +// NewClient returns a real OpenAI client as LLMClient +func NewClient(APIKey, URL, timeout string) LLMClient { // Set up OpenAI client if APIKey == "" { //log.Fatal("OPENAI_API_KEY environment variable not set") @@ -18,11 +38,12 @@ func NewClient(APIKey, URL, timeout string) *openai.Client { dur, err := time.ParseDuration(timeout) if err != nil { + xlog.Error("Failed to parse timeout", "error", err) dur = 150 * time.Second } config.HTTPClient = &http.Client{ Timeout: dur, } - return openai.NewClientWithConfig(config) + return &realClient{openai.NewClientWithConfig(config)} } diff --git a/pkg/llm/json.go b/pkg/llm/json.go index c4f48d1..34d413b 100644 --- a/pkg/llm/json.go +++ b/pkg/llm/json.go @@ -10,7 +10,7 @@ import ( "github.com/sashabaranov/go-openai/jsonschema" ) -func GenerateTypedJSONWithGuidance(ctx context.Context, client *openai.Client, guidance, model string, i jsonschema.Definition, dst any) error { +func GenerateTypedJSONWithGuidance(ctx context.Context, client LLMClient, guidance, model string, i jsonschema.Definition, dst any) error { return GenerateTypedJSONWithConversation(ctx, client, []openai.ChatCompletionMessage{ { Role: "user", @@ -19,7 +19,7 @@ func GenerateTypedJSONWithGuidance(ctx context.Context, client *openai.Client, g }, model, i, dst) } -func GenerateTypedJSONWithConversation(ctx context.Context, client *openai.Client, conv []openai.ChatCompletionMessage, model string, i jsonschema.Definition, dst any) error { +func GenerateTypedJSONWithConversation(ctx context.Context, client LLMClient, conv []openai.ChatCompletionMessage, model string, i jsonschema.Definition, dst any) error { toolName := "json" decision := openai.ChatCompletionRequest{ Model: model, diff --git a/pkg/llm/mock_client.go b/pkg/llm/mock_client.go new file mode 100644 index 0000000..52bc527 --- /dev/null +++ b/pkg/llm/mock_client.go @@ -0,0 +1,25 @@ +package llm + +import ( + "context" + "github.com/sashabaranov/go-openai" +) + +type MockClient struct { + CreateChatCompletionFunc func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) + CreateImageFunc func(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) +} + +func (m *MockClient) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + if m.CreateChatCompletionFunc != nil { + return m.CreateChatCompletionFunc(ctx, req) + } + return openai.ChatCompletionResponse{}, nil +} + +func (m *MockClient) CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) { + if m.CreateImageFunc != nil { + return m.CreateImageFunc(ctx, req) + } + return openai.ImageResponse{}, nil +} diff --git a/services/filters/classifier.go b/services/filters/classifier.go index e517a7c..840663c 100644 --- a/services/filters/classifier.go +++ b/services/filters/classifier.go @@ -8,7 +8,6 @@ import ( "github.com/mudler/LocalAGI/core/types" "github.com/mudler/LocalAGI/pkg/config" "github.com/mudler/LocalAGI/pkg/llm" - "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/jsonschema" ) @@ -16,7 +15,7 @@ const FilterClassifier = "classifier" type ClassifierFilter struct { name string - client *openai.Client + client llm.LLMClient model string description string allowOnMatch bool