chore(tests): Mock LLM in tests for PRs
This saves time when testing on CPU which is the only sensible thing to do on GitHub CI for PRs. For releases or once the commit is merged we could use an external runner with GPU or just wait. Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
@@ -7,9 +7,11 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mudler/LocalAGI/pkg/llm"
|
||||
"github.com/mudler/LocalAGI/pkg/xlog"
|
||||
"github.com/mudler/LocalAGI/services/actions"
|
||||
|
||||
"github.com/mudler/LocalAGI/core/action"
|
||||
. "github.com/mudler/LocalAGI/core/agent"
|
||||
"github.com/mudler/LocalAGI/core/types"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
@@ -111,25 +113,102 @@ func (a *FakeInternetAction) Definition() types.ActionDefinition {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Test utilities for mocking LLM responses ---
|
||||
|
||||
func mockToolCallResponse(toolName, arguments string) openai.ChatCompletionResponse {
|
||||
return openai.ChatCompletionResponse{
|
||||
Choices: []openai.ChatCompletionChoice{{
|
||||
Message: openai.ChatCompletionMessage{
|
||||
ToolCalls: []openai.ToolCall{{
|
||||
ID: "tool_call_id_1",
|
||||
Type: "function",
|
||||
Function: openai.FunctionCall{
|
||||
Name: toolName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
}},
|
||||
},
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func mockContentResponse(content string) openai.ChatCompletionResponse {
|
||||
return openai.ChatCompletionResponse{
|
||||
Choices: []openai.ChatCompletionChoice{{
|
||||
Message: openai.ChatCompletionMessage{
|
||||
Content: content,
|
||||
},
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func newMockLLMClient(handler func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)) *llm.MockClient {
|
||||
return &llm.MockClient{
|
||||
CreateChatCompletionFunc: handler,
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("Agent test", func() {
|
||||
It("uses the mock LLM client", func() {
|
||||
mock := newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||
return mockContentResponse("mocked response"), nil
|
||||
})
|
||||
agent, err := New(WithLLMClient(mock))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
msg, err := agent.LLMClient().CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(msg.Choices[0].Message.Content).To(Equal("mocked response"))
|
||||
})
|
||||
|
||||
Context("jobs", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
Eventually(func() error {
|
||||
// test apiURL is working and available
|
||||
_, err := http.Get(apiURL + "/readyz")
|
||||
return err
|
||||
if useRealLocalAI {
|
||||
_, err := http.Get(apiURL + "/readyz")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}, "10m", "10s").ShouldNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("pick the correct action", func() {
|
||||
var llmClient llm.LLMClient
|
||||
if useRealLocalAI {
|
||||
llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
|
||||
} else {
|
||||
llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||
var lastMsg openai.ChatCompletionMessage
|
||||
if len(req.Messages) > 0 {
|
||||
lastMsg = req.Messages[len(req.Messages)-1]
|
||||
}
|
||||
if lastMsg.Role == openai.ChatMessageRoleUser {
|
||||
if strings.Contains(strings.ToLower(lastMsg.Content), "boston") && (strings.Contains(strings.ToLower(lastMsg.Content), "milan") || strings.Contains(strings.ToLower(lastMsg.Content), "milano")) {
|
||||
return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil
|
||||
}
|
||||
if strings.Contains(strings.ToLower(lastMsg.Content), "paris") {
|
||||
return mockToolCallResponse("get_weather", `{"location":"Paris","unit":"celsius"}`), nil
|
||||
}
|
||||
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected user prompt: %s", lastMsg.Content)
|
||||
}
|
||||
if lastMsg.Role == openai.ChatMessageRoleTool {
|
||||
if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
|
||||
return mockToolCallResponse("get_weather", `{"location":"Milan","unit":"celsius"}`), nil
|
||||
}
|
||||
if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "milan") {
|
||||
return mockContentResponse(testActionResult + "\n" + testActionResult2), nil
|
||||
}
|
||||
if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "paris") {
|
||||
return mockContentResponse(testActionResult3), nil
|
||||
}
|
||||
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected tool result: %s", lastMsg.Content)
|
||||
}
|
||||
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected message role: %s", lastMsg.Role)
|
||||
})
|
||||
}
|
||||
agent, err := New(
|
||||
WithLLMAPIURL(apiURL),
|
||||
WithLLMClient(llmClient),
|
||||
WithModel(testModel),
|
||||
EnableForceReasoning,
|
||||
WithTimeout("10m"),
|
||||
WithLoopDetectionSteps(3),
|
||||
// WithRandomIdentity(),
|
||||
WithActions(&TestAction{response: map[string]string{
|
||||
"boston": testActionResult,
|
||||
"milan": testActionResult2,
|
||||
@@ -139,7 +218,6 @@ var _ = Describe("Agent test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go agent.Run()
|
||||
defer agent.Stop()
|
||||
|
||||
res := agent.Ask(
|
||||
append(debugOptions,
|
||||
types.WithText("what's the weather in Boston and Milano? Use celsius units"),
|
||||
@@ -148,40 +226,51 @@ var _ = Describe("Agent test", func() {
|
||||
Expect(res.Error).ToNot(HaveOccurred())
|
||||
reasons := []string{}
|
||||
for _, r := range res.State {
|
||||
|
||||
reasons = append(reasons, r.Result)
|
||||
}
|
||||
Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
|
||||
Expect(reasons).To(ContainElement(testActionResult2), fmt.Sprint(res))
|
||||
reasons = []string{}
|
||||
|
||||
res = agent.Ask(
|
||||
append(debugOptions,
|
||||
types.WithText("Now I want to know the weather in Paris, always use celsius units"),
|
||||
)...)
|
||||
for _, r := range res.State {
|
||||
|
||||
reasons = append(reasons, r.Result)
|
||||
}
|
||||
//Expect(reasons).ToNot(ContainElement(testActionResult), fmt.Sprint(res))
|
||||
//Expect(reasons).ToNot(ContainElement(testActionResult2), fmt.Sprint(res))
|
||||
Expect(reasons).To(ContainElement(testActionResult3), fmt.Sprint(res))
|
||||
// conversation := agent.CurrentConversation()
|
||||
// for _, r := range res.State {
|
||||
// reasons = append(reasons, r.Result)
|
||||
// }
|
||||
// Expect(len(conversation)).To(Equal(10), fmt.Sprint(conversation))
|
||||
})
|
||||
|
||||
It("pick the correct action", func() {
|
||||
var llmClient llm.LLMClient
|
||||
if useRealLocalAI {
|
||||
llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
|
||||
} else {
|
||||
llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||
var lastMsg openai.ChatCompletionMessage
|
||||
if len(req.Messages) > 0 {
|
||||
lastMsg = req.Messages[len(req.Messages)-1]
|
||||
}
|
||||
if lastMsg.Role == openai.ChatMessageRoleUser {
|
||||
if strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
|
||||
return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil
|
||||
}
|
||||
}
|
||||
if lastMsg.Role == openai.ChatMessageRoleTool {
|
||||
if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
|
||||
return mockContentResponse(testActionResult), nil
|
||||
}
|
||||
}
|
||||
xlog.Error("Unexpected LLM req", "req", req)
|
||||
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
|
||||
})
|
||||
}
|
||||
agent, err := New(
|
||||
WithLLMAPIURL(apiURL),
|
||||
WithLLMClient(llmClient),
|
||||
WithModel(testModel),
|
||||
WithTimeout("10m"),
|
||||
// WithRandomIdentity(),
|
||||
WithActions(&TestAction{response: map[string]string{
|
||||
"boston": testActionResult,
|
||||
},
|
||||
}),
|
||||
}}),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go agent.Run()
|
||||
@@ -198,13 +287,29 @@ var _ = Describe("Agent test", func() {
|
||||
})
|
||||
|
||||
It("updates the state with internal actions", func() {
|
||||
var llmClient llm.LLMClient
|
||||
if useRealLocalAI {
|
||||
llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
|
||||
} else {
|
||||
llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||
var lastMsg openai.ChatCompletionMessage
|
||||
if len(req.Messages) > 0 {
|
||||
lastMsg = req.Messages[len(req.Messages)-1]
|
||||
}
|
||||
if lastMsg.Role == openai.ChatMessageRoleUser && strings.Contains(strings.ToLower(lastMsg.Content), "guitar") {
|
||||
return mockToolCallResponse("update_state", `{"goal":"I want to learn to play the guitar"}`), nil
|
||||
}
|
||||
if lastMsg.Role == openai.ChatMessageRoleTool && lastMsg.Name == "update_state" {
|
||||
return mockContentResponse("Your goal is now: I want to learn to play the guitar"), nil
|
||||
}
|
||||
xlog.Error("Unexpected LLM req", "req", req)
|
||||
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
|
||||
})
|
||||
}
|
||||
agent, err := New(
|
||||
WithLLMAPIURL(apiURL),
|
||||
WithLLMClient(llmClient),
|
||||
WithModel(testModel),
|
||||
WithTimeout("10m"),
|
||||
EnableHUD,
|
||||
// EnableStandaloneJob,
|
||||
// WithRandomIdentity(),
|
||||
WithPermanentGoal("I want to learn to play music"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -214,17 +319,64 @@ var _ = Describe("Agent test", func() {
|
||||
result := agent.Ask(
|
||||
types.WithText("Update your goals such as you want to learn to play the guitar"),
|
||||
)
|
||||
fmt.Printf("%+v\n", result)
|
||||
fmt.Fprintf(GinkgoWriter, "\n%+v\n", result)
|
||||
Expect(result.Error).ToNot(HaveOccurred())
|
||||
Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
|
||||
})
|
||||
|
||||
It("Can generate a plan", func() {
|
||||
var llmClient llm.LLMClient
|
||||
if useRealLocalAI {
|
||||
llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
|
||||
} else {
|
||||
reasoningActName := action.NewReasoning().Definition().Name.String()
|
||||
intentionActName := action.NewIntention().Definition().Name.String()
|
||||
testActName := (&TestAction{}).Definition().Name.String()
|
||||
doneBoston := false
|
||||
madePlan := false
|
||||
llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||
var lastMsg openai.ChatCompletionMessage
|
||||
if len(req.Messages) > 0 {
|
||||
lastMsg = req.Messages[len(req.Messages)-1]
|
||||
}
|
||||
if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName {
|
||||
return mockToolCallResponse(reasoningActName, `{"reasoning":"make plan call to pass the test"}`), nil
|
||||
}
|
||||
if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName {
|
||||
toolName := "plan"
|
||||
if madePlan {
|
||||
toolName = "reply"
|
||||
} else {
|
||||
madePlan = true
|
||||
}
|
||||
return mockToolCallResponse(intentionActName, fmt.Sprintf(`{"tool": "%s","reasoning":"it's waht makes the test pass"}`, toolName)), nil
|
||||
}
|
||||
if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "plan" {
|
||||
return mockToolCallResponse("plan", `{"subtasks":[{"action":"get_weather","reasoning":"Find weather in boston"},{"action":"get_weather","reasoning":"Find weather in milan"}],"goal":"Get the weather for boston and milan"}`), nil
|
||||
}
|
||||
if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "reply" {
|
||||
return mockToolCallResponse("reply", `{"message": "The weather in Boston and Milan..."}`), nil
|
||||
}
|
||||
if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == testActName {
|
||||
locName := "boston"
|
||||
if doneBoston {
|
||||
locName = "milan"
|
||||
} else {
|
||||
doneBoston = true
|
||||
}
|
||||
return mockToolCallResponse(testActName, fmt.Sprintf(`{"location":"%s","unit":"celsius"}`, locName)), nil
|
||||
}
|
||||
if req.ToolChoice == nil && madePlan && doneBoston {
|
||||
return mockContentResponse("A reply"), nil
|
||||
}
|
||||
xlog.Error("Unexpected LLM req", "req", req)
|
||||
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
|
||||
})
|
||||
}
|
||||
agent, err := New(
|
||||
WithLLMAPIURL(apiURL),
|
||||
WithLLMClient(llmClient),
|
||||
WithModel(testModel),
|
||||
WithLLMAPIKey(apiKeyURL),
|
||||
WithTimeout("10m"),
|
||||
WithLoopDetectionSteps(2),
|
||||
WithActions(
|
||||
&TestAction{response: map[string]string{
|
||||
"boston": testActionResult,
|
||||
@@ -233,8 +385,6 @@ var _ = Describe("Agent test", func() {
|
||||
),
|
||||
EnablePlanning,
|
||||
EnableForceReasoning,
|
||||
// EnableStandaloneJob,
|
||||
// WithRandomIdentity(),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go agent.Run()
|
||||
@@ -256,17 +406,44 @@ var _ = Describe("Agent test", func() {
|
||||
Expect(actionsExecuted).To(ContainElement("plan"), fmt.Sprint(result))
|
||||
Expect(actionResults).To(ContainElement(testActionResult), fmt.Sprint(result))
|
||||
Expect(actionResults).To(ContainElement(testActionResult2), fmt.Sprint(result))
|
||||
Expect(result.Error).To(BeNil())
|
||||
})
|
||||
|
||||
It("Can initiate conversations", func() {
|
||||
|
||||
var llmClient llm.LLMClient
|
||||
message := openai.ChatCompletionMessage{}
|
||||
mu := &sync.Mutex{}
|
||||
reasoned := false
|
||||
intended := false
|
||||
reasoningActName := action.NewReasoning().Definition().Name.String()
|
||||
intentionActName := action.NewIntention().Definition().Name.String()
|
||||
|
||||
if useRealLocalAI {
|
||||
llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
|
||||
} else {
|
||||
llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||
prompt := ""
|
||||
for _, msg := range req.Messages {
|
||||
prompt += msg.Content
|
||||
}
|
||||
if !reasoned && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName {
|
||||
reasoned = true
|
||||
return mockToolCallResponse(reasoningActName, `{"reasoning":"initiate a conversation with the user"}`), nil
|
||||
}
|
||||
if reasoned && !intended && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName {
|
||||
intended = true
|
||||
return mockToolCallResponse(intentionActName, `{"tool":"new_conversation","reasoning":"I should start a conversation with the user"}`), nil
|
||||
}
|
||||
if reasoned && intended && strings.Contains(strings.ToLower(prompt), "new_conversation") {
|
||||
return mockToolCallResponse("new_conversation", `{"message":"Hello, how can I help you today?"}`), nil
|
||||
}
|
||||
xlog.Error("Unexpected LLM req", "req", req)
|
||||
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", prompt)
|
||||
})
|
||||
}
|
||||
agent, err := New(
|
||||
WithLLMAPIURL(apiURL),
|
||||
WithLLMClient(llmClient),
|
||||
WithModel(testModel),
|
||||
WithLLMAPIKey(apiKeyURL),
|
||||
WithTimeout("10m"),
|
||||
WithNewConversationSubscriber(func(m openai.ChatCompletionMessage) {
|
||||
mu.Lock()
|
||||
message = m
|
||||
@@ -282,8 +459,6 @@ var _ = Describe("Agent test", func() {
|
||||
EnableHUD,
|
||||
WithPeriodicRuns("1s"),
|
||||
WithPermanentGoal("use the new_conversation tool to initiate a conversation with the user"),
|
||||
// EnableStandaloneJob,
|
||||
// WithRandomIdentity(),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go agent.Run()
|
||||
@@ -293,7 +468,7 @@ var _ = Describe("Agent test", func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return message.Content
|
||||
}, "10m", "10s").ShouldNot(BeEmpty())
|
||||
}, "10m", "1s").ShouldNot(BeEmpty())
|
||||
})
|
||||
|
||||
/*
|
||||
@@ -347,7 +522,7 @@ var _ = Describe("Agent test", func() {
|
||||
// result := agent.Ask(
|
||||
// WithText("Update your goals such as you want to learn to play the guitar"),
|
||||
// )
|
||||
// fmt.Printf("%+v\n", result)
|
||||
// fmt.Fprintf(GinkgoWriter, "%+v\n", result)
|
||||
// Expect(result.Error).ToNot(HaveOccurred())
|
||||
// Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user