chore(tests): Mock LLM in tests for PRs
This saves time when testing on CPU, which is the only sensible option on GitHub CI for PRs. For releases, or once a commit is merged, we could use an external runner with a GPU, or simply wait.

Signed-off-by: Richard Palethorpe <io@richiejp.com>
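Note: the diff swaps the agent's concrete *openai.Client for an llm.LLMClient interface and stubs it in tests with llm.MockClient. Neither definition appears in this diff, so the sketch below is a guess at the minimal shape pkg/llm would need, inferred purely from the call sites (a CreateChatCompletion method on the interface, a CreateChatCompletionFunc field on the mock); the real interface very likely covers more of the OpenAI client surface.

package llm

import (
	"context"

	"github.com/sashabaranov/go-openai"
)

// LLMClient is the narrow interface the agent is assumed to depend on
// instead of *openai.Client (inferred from usage; not shown in the diff).
type LLMClient interface {
	CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)
}

// MockClient satisfies LLMClient by delegating to a test-supplied function,
// which is how the tests below script canned responses per request.
type MockClient struct {
	CreateChatCompletionFunc func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)
}

func (m *MockClient) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
	return m.CreateChatCompletionFunc(ctx, req)
}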
@@ -29,7 +29,7 @@ type Agent struct {
 	sync.Mutex
 	options *options
 	Character Character
-	client *openai.Client
+	client llm.LLMClient
 	jobQueue chan *types.Job
 	context *types.ActionContext
 
@@ -63,7 +63,12 @@ func New(opts ...Option) (*Agent, error) {
 		return nil, fmt.Errorf("failed to set options: %v", err)
 	}
 
-	client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout)
+	var client llm.LLMClient
+	if options.llmClient != nil {
+		client = options.llmClient
+	} else {
+		client = llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout)
+	}
 
 	c := context.Background()
 	if options.context != nil {
@@ -125,6 +130,11 @@ func (a *Agent) SharedState() *types.AgentSharedState {
 	return a.sharedState
 }
 
+// LLMClient returns the agent's LLM client (for testing)
+func (a *Agent) LLMClient() llm.LLMClient {
+	return a.client
+}
+
 func (a *Agent) startNewConversationsConsumer() {
 	go func() {
 		for {
@@ -1,6 +1,7 @@
 package agent_test
 
 import (
+	"net/url"
 	"os"
 	"testing"
 
@@ -13,15 +14,19 @@ func TestAgent(t *testing.T) {
 	RunSpecs(t, "Agent test suite")
 }
 
-var testModel = os.Getenv("LOCALAGI_MODEL")
-var apiURL = os.Getenv("LOCALAI_API_URL")
-var apiKeyURL = os.Getenv("LOCALAI_API_KEY")
+var (
+	testModel      = os.Getenv("LOCALAGI_MODEL")
+	apiURL         = os.Getenv("LOCALAI_API_URL")
+	apiKey         = os.Getenv("LOCALAI_API_KEY")
+	useRealLocalAI bool
+	clientTimeout  = "10m"
+)
+
+func isValidURL(u string) bool {
+	parsed, err := url.ParseRequestURI(u)
+	return err == nil && parsed.Scheme != "" && parsed.Host != ""
+}
 
 func init() {
 	if testModel == "" {
 		testModel = "hermes-2-pro-mistral"
 	}
 	if apiURL == "" {
 		apiURL = "http://192.168.68.113:8080"
 	}
+	useRealLocalAI = isValidURL(apiURL) && apiURL != "" && testModel != ""
 }
@@ -7,9 +7,11 @@ import (
 	"strings"
 	"sync"
 
+	"github.com/mudler/LocalAGI/pkg/llm"
 	"github.com/mudler/LocalAGI/pkg/xlog"
 	"github.com/mudler/LocalAGI/services/actions"
 
+	"github.com/mudler/LocalAGI/core/action"
 	. "github.com/mudler/LocalAGI/core/agent"
 	"github.com/mudler/LocalAGI/core/types"
 	. "github.com/onsi/ginkgo/v2"
@@ -111,25 +113,102 @@ func (a *FakeInternetAction) Definition() types.ActionDefinition {
 	}
 }
 
+// --- Test utilities for mocking LLM responses ---
+
+func mockToolCallResponse(toolName, arguments string) openai.ChatCompletionResponse {
+	return openai.ChatCompletionResponse{
+		Choices: []openai.ChatCompletionChoice{{
+			Message: openai.ChatCompletionMessage{
+				ToolCalls: []openai.ToolCall{{
+					ID:   "tool_call_id_1",
+					Type: "function",
+					Function: openai.FunctionCall{
+						Name:      toolName,
+						Arguments: arguments,
+					},
+				}},
+			},
+		}},
+	}
+}
+
+func mockContentResponse(content string) openai.ChatCompletionResponse {
+	return openai.ChatCompletionResponse{
+		Choices: []openai.ChatCompletionChoice{{
+			Message: openai.ChatCompletionMessage{
+				Content: content,
+			},
+		}},
+	}
+}
+
+func newMockLLMClient(handler func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)) *llm.MockClient {
+	return &llm.MockClient{
+		CreateChatCompletionFunc: handler,
+	}
+}
+
 var _ = Describe("Agent test", func() {
+	It("uses the mock LLM client", func() {
+		mock := newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+			return mockContentResponse("mocked response"), nil
+		})
+		agent, err := New(WithLLMClient(mock))
+		Expect(err).ToNot(HaveOccurred())
+		msg, err := agent.LLMClient().CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{})
+		Expect(err).ToNot(HaveOccurred())
+		Expect(msg.Choices[0].Message.Content).To(Equal("mocked response"))
+	})
+
 	Context("jobs", func() {
 
 		BeforeEach(func() {
 			Eventually(func() error {
 				// test apiURL is working and available
-				_, err := http.Get(apiURL + "/readyz")
-				return err
+				if useRealLocalAI {
+					_, err := http.Get(apiURL + "/readyz")
+					return err
+				}
+				return nil
 			}, "10m", "10s").ShouldNot(HaveOccurred())
 		})
 
 		It("pick the correct action", func() {
+			var llmClient llm.LLMClient
+			if useRealLocalAI {
+				llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+			} else {
+				llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+					var lastMsg openai.ChatCompletionMessage
+					if len(req.Messages) > 0 {
+						lastMsg = req.Messages[len(req.Messages)-1]
+					}
+					if lastMsg.Role == openai.ChatMessageRoleUser {
+						if strings.Contains(strings.ToLower(lastMsg.Content), "boston") && (strings.Contains(strings.ToLower(lastMsg.Content), "milan") || strings.Contains(strings.ToLower(lastMsg.Content), "milano")) {
+							return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil
+						}
+						if strings.Contains(strings.ToLower(lastMsg.Content), "paris") {
+							return mockToolCallResponse("get_weather", `{"location":"Paris","unit":"celsius"}`), nil
+						}
+						return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected user prompt: %s", lastMsg.Content)
+					}
+					if lastMsg.Role == openai.ChatMessageRoleTool {
+						if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
+							return mockToolCallResponse("get_weather", `{"location":"Milan","unit":"celsius"}`), nil
+						}
+						if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "milan") {
+							return mockContentResponse(testActionResult + "\n" + testActionResult2), nil
+						}
+						if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "paris") {
+							return mockContentResponse(testActionResult3), nil
+						}
+						return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected tool result: %s", lastMsg.Content)
+					}
+					return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected message role: %s", lastMsg.Role)
+				})
+			}
 			agent, err := New(
 				WithLLMAPIURL(apiURL),
+				WithLLMClient(llmClient),
 				WithModel(testModel),
 				EnableForceReasoning,
 				WithTimeout("10m"),
 				WithLoopDetectionSteps(3),
 				// WithRandomIdentity(),
 				WithActions(&TestAction{response: map[string]string{
 					"boston": testActionResult,
 					"milan":  testActionResult2,
@@ -139,7 +218,6 @@ var _ = Describe("Agent test", func() {
 			Expect(err).ToNot(HaveOccurred())
 			go agent.Run()
 			defer agent.Stop()
-
 			res := agent.Ask(
 				append(debugOptions,
 					types.WithText("what's the weather in Boston and Milano? Use celsius units"),
@@ -148,40 +226,51 @@ var _ = Describe("Agent test", func() {
 			Expect(res.Error).ToNot(HaveOccurred())
 			reasons := []string{}
 			for _, r := range res.State {
 				reasons = append(reasons, r.Result)
 			}
 			Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
 			Expect(reasons).To(ContainElement(testActionResult2), fmt.Sprint(res))
 			reasons = []string{}
 
 			res = agent.Ask(
 				append(debugOptions,
 					types.WithText("Now I want to know the weather in Paris, always use celsius units"),
 				)...)
 			for _, r := range res.State {
 				reasons = append(reasons, r.Result)
 			}
 			//Expect(reasons).ToNot(ContainElement(testActionResult), fmt.Sprint(res))
 			//Expect(reasons).ToNot(ContainElement(testActionResult2), fmt.Sprint(res))
 			Expect(reasons).To(ContainElement(testActionResult3), fmt.Sprint(res))
 			// conversation := agent.CurrentConversation()
 			// for _, r := range res.State {
 			// 	reasons = append(reasons, r.Result)
 			// }
 			// Expect(len(conversation)).To(Equal(10), fmt.Sprint(conversation))
 		})
 
 		It("pick the correct action", func() {
+			var llmClient llm.LLMClient
+			if useRealLocalAI {
+				llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+			} else {
+				llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+					var lastMsg openai.ChatCompletionMessage
+					if len(req.Messages) > 0 {
+						lastMsg = req.Messages[len(req.Messages)-1]
+					}
+					if lastMsg.Role == openai.ChatMessageRoleUser {
+						if strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
+							return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil
+						}
+					}
+					if lastMsg.Role == openai.ChatMessageRoleTool {
+						if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
+							return mockContentResponse(testActionResult), nil
+						}
+					}
+					xlog.Error("Unexpected LLM req", "req", req)
+					return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
+				})
+			}
 			agent, err := New(
 				WithLLMAPIURL(apiURL),
+				WithLLMClient(llmClient),
 				WithModel(testModel),
 				WithTimeout("10m"),
 				// WithRandomIdentity(),
 				WithActions(&TestAction{response: map[string]string{
 					"boston": testActionResult,
-				},
-				}),
+				}}),
 			)
 			Expect(err).ToNot(HaveOccurred())
 			go agent.Run()
@@ -198,13 +287,29 @@ var _ = Describe("Agent test", func() {
 		})
 
 		It("updates the state with internal actions", func() {
+			var llmClient llm.LLMClient
+			if useRealLocalAI {
+				llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+			} else {
+				llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+					var lastMsg openai.ChatCompletionMessage
+					if len(req.Messages) > 0 {
+						lastMsg = req.Messages[len(req.Messages)-1]
+					}
+					if lastMsg.Role == openai.ChatMessageRoleUser && strings.Contains(strings.ToLower(lastMsg.Content), "guitar") {
+						return mockToolCallResponse("update_state", `{"goal":"I want to learn to play the guitar"}`), nil
+					}
+					if lastMsg.Role == openai.ChatMessageRoleTool && lastMsg.Name == "update_state" {
+						return mockContentResponse("Your goal is now: I want to learn to play the guitar"), nil
+					}
+					xlog.Error("Unexpected LLM req", "req", req)
+					return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
+				})
+			}
 			agent, err := New(
 				WithLLMAPIURL(apiURL),
+				WithLLMClient(llmClient),
 				WithModel(testModel),
 				WithTimeout("10m"),
 				EnableHUD,
 				// EnableStandaloneJob,
 				// WithRandomIdentity(),
 				WithPermanentGoal("I want to learn to play music"),
 			)
 			Expect(err).ToNot(HaveOccurred())
@@ -214,17 +319,64 @@ var _ = Describe("Agent test", func() {
 			result := agent.Ask(
 				types.WithText("Update your goals such as you want to learn to play the guitar"),
 			)
-			fmt.Printf("%+v\n", result)
+			fmt.Fprintf(GinkgoWriter, "\n%+v\n", result)
 			Expect(result.Error).ToNot(HaveOccurred())
 			Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
 		})
 
 		It("Can generate a plan", func() {
+			var llmClient llm.LLMClient
+			if useRealLocalAI {
+				llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+			} else {
+				reasoningActName := action.NewReasoning().Definition().Name.String()
+				intentionActName := action.NewIntention().Definition().Name.String()
+				testActName := (&TestAction{}).Definition().Name.String()
+				doneBoston := false
+				madePlan := false
+				llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+					var lastMsg openai.ChatCompletionMessage
+					if len(req.Messages) > 0 {
+						lastMsg = req.Messages[len(req.Messages)-1]
+					}
+					if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName {
+						return mockToolCallResponse(reasoningActName, `{"reasoning":"make plan call to pass the test"}`), nil
+					}
+					if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName {
+						toolName := "plan"
+						if madePlan {
+							toolName = "reply"
+						} else {
+							madePlan = true
+						}
+						return mockToolCallResponse(intentionActName, fmt.Sprintf(`{"tool": "%s","reasoning":"it's what makes the test pass"}`, toolName)), nil
+					}
+					if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "plan" {
+						return mockToolCallResponse("plan", `{"subtasks":[{"action":"get_weather","reasoning":"Find weather in boston"},{"action":"get_weather","reasoning":"Find weather in milan"}],"goal":"Get the weather for boston and milan"}`), nil
+					}
+					if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "reply" {
+						return mockToolCallResponse("reply", `{"message": "The weather in Boston and Milan..."}`), nil
+					}
+					if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == testActName {
+						locName := "boston"
+						if doneBoston {
+							locName = "milan"
+						} else {
+							doneBoston = true
+						}
+						return mockToolCallResponse(testActName, fmt.Sprintf(`{"location":"%s","unit":"celsius"}`, locName)), nil
+					}
+					if req.ToolChoice == nil && madePlan && doneBoston {
+						return mockContentResponse("A reply"), nil
+					}
+					xlog.Error("Unexpected LLM req", "req", req)
+					return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
+				})
+			}
 			agent, err := New(
 				WithLLMAPIURL(apiURL),
+				WithLLMClient(llmClient),
 				WithModel(testModel),
-				WithLLMAPIKey(apiKeyURL),
+				WithLLMAPIKey(apiKey),
 				WithTimeout("10m"),
 				WithLoopDetectionSteps(2),
 				WithActions(
 					&TestAction{response: map[string]string{
 						"boston": testActionResult,
@@ -233,8 +385,6 @@ var _ = Describe("Agent test", func() {
 				),
 				EnablePlanning,
 				EnableForceReasoning,
-				// EnableStandaloneJob,
-				// WithRandomIdentity(),
 			)
 			Expect(err).ToNot(HaveOccurred())
 			go agent.Run()
@@ -256,17 +406,44 @@ var _ = Describe("Agent test", func() {
 			Expect(actionsExecuted).To(ContainElement("plan"), fmt.Sprint(result))
 			Expect(actionResults).To(ContainElement(testActionResult), fmt.Sprint(result))
 			Expect(actionResults).To(ContainElement(testActionResult2), fmt.Sprint(result))
 			Expect(result.Error).To(BeNil())
 		})
 
 		It("Can initiate conversations", func() {
+			var llmClient llm.LLMClient
 			message := openai.ChatCompletionMessage{}
 			mu := &sync.Mutex{}
+			reasoned := false
+			intended := false
+			reasoningActName := action.NewReasoning().Definition().Name.String()
+			intentionActName := action.NewIntention().Definition().Name.String()
+
+			if useRealLocalAI {
+				llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+			} else {
+				llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+					prompt := ""
+					for _, msg := range req.Messages {
+						prompt += msg.Content
+					}
+					if !reasoned && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName {
+						reasoned = true
+						return mockToolCallResponse(reasoningActName, `{"reasoning":"initiate a conversation with the user"}`), nil
+					}
+					if reasoned && !intended && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName {
+						intended = true
+						return mockToolCallResponse(intentionActName, `{"tool":"new_conversation","reasoning":"I should start a conversation with the user"}`), nil
+					}
+					if reasoned && intended && strings.Contains(strings.ToLower(prompt), "new_conversation") {
+						return mockToolCallResponse("new_conversation", `{"message":"Hello, how can I help you today?"}`), nil
+					}
+					xlog.Error("Unexpected LLM req", "req", req)
+					return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", prompt)
+				})
+			}
 			agent, err := New(
 				WithLLMAPIURL(apiURL),
+				WithLLMClient(llmClient),
 				WithModel(testModel),
-				WithLLMAPIKey(apiKeyURL),
+				WithLLMAPIKey(apiKey),
 				WithTimeout("10m"),
 				WithNewConversationSubscriber(func(m openai.ChatCompletionMessage) {
 					mu.Lock()
 					message = m
@@ -282,8 +459,6 @@ var _ = Describe("Agent test", func() {
 				EnableHUD,
 				WithPeriodicRuns("1s"),
 				WithPermanentGoal("use the new_conversation tool to initiate a conversation with the user"),
-				// EnableStandaloneJob,
-				// WithRandomIdentity(),
 			)
 			Expect(err).ToNot(HaveOccurred())
 			go agent.Run()
@@ -293,7 +468,7 @@ var _ = Describe("Agent test", func() {
 				mu.Lock()
 				defer mu.Unlock()
 				return message.Content
-			}, "10m", "10s").ShouldNot(BeEmpty())
+			}, "10m", "1s").ShouldNot(BeEmpty())
 		})
 
 		/*
@@ -347,7 +522,7 @@ var _ = Describe("Agent test", func() {
 		// result := agent.Ask(
 		// 	WithText("Update your goals such as you want to learn to play the guitar"),
 		// )
-		// fmt.Printf("%+v\n", result)
+		// fmt.Fprintf(GinkgoWriter, "%+v\n", result)
 		// Expect(result.Error).ToNot(HaveOccurred())
 		// Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
 	})
@@ -7,6 +7,7 @@ import (
 
 	"github.com/mudler/LocalAGI/core/types"
 	"github.com/sashabaranov/go-openai"
+	"github.com/mudler/LocalAGI/pkg/llm"
 )
 
 type Option func(*options) error
@@ -19,6 +20,7 @@ type llmOptions struct {
 }
 
 type options struct {
+	llmClient              llm.LLMClient
 	LLMAPI                 llmOptions
 	character              Character
 	randomIdentityGuidance string
@@ -68,6 +70,14 @@ type options struct {
 	lastMessageDuration time.Duration
 }
 
+// WithLLMClient allows injecting a custom LLM client (e.g. for testing)
+func WithLLMClient(client llm.LLMClient) Option {
+	return func(o *options) error {
+		o.llmClient = client
+		return nil
+	}
+}
+
 func (o *options) SeparatedMultimodalModel() bool {
 	return o.LLMAPI.MultimodalModel != "" && o.LLMAPI.Model != o.LLMAPI.MultimodalModel
 }
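Taken together, the WithLLMClient option and the LLMClient accessor let a test run without any LocalAI instance. A condensed usage sketch in plain testing style (no Ginkgo), using a named import of core/agent rather than the dot-import the test files use; illustrative only, assuming the MockClient shape sketched earlier:

package agent_test

import (
	"context"
	"testing"

	"github.com/mudler/LocalAGI/core/agent"
	"github.com/mudler/LocalAGI/pkg/llm"
	"github.com/sashabaranov/go-openai"
)

func TestAgentWithMockedLLM(t *testing.T) {
	// Every completion request gets a fixed canned answer; no GPU, no network.
	mock := &llm.MockClient{
		CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
			return openai.ChatCompletionResponse{
				Choices: []openai.ChatCompletionChoice{{
					Message: openai.ChatCompletionMessage{Content: "mocked response"},
				}},
			}, nil
		},
	}

	a, err := agent.New(agent.WithLLMClient(mock)) // no API URL or key required
	if err != nil {
		t.Fatal(err)
	}

	resp, err := a.LLMClient().CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{})
	if err != nil {
		t.Fatal(err)
	}
	if got := resp.Choices[0].Message.Content; got != "mocked response" {
		t.Fatalf("unexpected content: %q", got)
	}
}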
@@ -1,29 +1,57 @@
 package agent_test
 
 import (
-	"net/http"
+	"context"
+	"fmt"
+
+	"github.com/mudler/LocalAGI/pkg/llm"
+	"github.com/sashabaranov/go-openai"
 
 	. "github.com/mudler/LocalAGI/core/agent"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
 )
 
 var _ = Describe("Agent test", func() {
 	Context("identity", func() {
 		var agent *Agent
 
-		BeforeEach(func() {
-			Eventually(func() error {
-				// test apiURL is working and available
-				_, err := http.Get(apiURL + "/readyz")
-				return err
-			}, "10m", "10s").ShouldNot(HaveOccurred())
-		})
+		// BeforeEach(func() {
+		// 	Eventually(func() error {
+		// 		// test apiURL is working and available
+		// 		_, err := http.Get(apiURL + "/readyz")
+		// 		return err
+		// 	}, "10m", "10s").ShouldNot(HaveOccurred())
+		// })
 
 		It("generates all the fields with random data", func() {
+			var llmClient llm.LLMClient
+			if useRealLocalAI {
+				llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+			} else {
+				llmClient = &llm.MockClient{
+					CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+						return openai.ChatCompletionResponse{
+							Choices: []openai.ChatCompletionChoice{{
+								Message: openai.ChatCompletionMessage{
+									ToolCalls: []openai.ToolCall{{
+										ID:   "tool_call_id_1",
+										Type: "function",
+										Function: openai.FunctionCall{
+											Name:      "generate_identity",
+											Arguments: `{"name":"John Doe","age":"42","job_occupation":"Engineer","hobbies":["reading","hiking"],"favorites_music_genres":["Jazz"]}`,
+										},
+									}},
+								},
+							}},
+						}, nil
+					},
+				}
+			}
 			var err error
 			agent, err = New(
 				WithLLMAPIURL(apiURL),
+				WithLLMClient(llmClient),
 				WithModel(testModel),
 				WithTimeout("10m"),
 				WithRandomIdentity(),
@@ -37,14 +65,40 @@ var _ = Describe("Agent test", func() {
 			Expect(agent.Character.MusicTaste).ToNot(BeEmpty())
 		})
 		It("detect an invalid character", func() {
+			mock := &llm.MockClient{
+				CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+					return openai.ChatCompletionResponse{}, fmt.Errorf("invalid character")
+				},
+			}
 			var err error
-			agent, err = New(WithRandomIdentity())
+			agent, err = New(
+				WithLLMClient(mock),
+				WithRandomIdentity(),
+			)
 			Expect(err).To(HaveOccurred())
 		})
 		It("generates all the fields", func() {
+			mock := &llm.MockClient{
+				CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+					return openai.ChatCompletionResponse{
+						Choices: []openai.ChatCompletionChoice{{
+							Message: openai.ChatCompletionMessage{
+								ToolCalls: []openai.ToolCall{{
+									ID:   "tool_call_id_2",
+									Type: "function",
+									Function: openai.FunctionCall{
+										Name:      "generate_identity",
+										Arguments: `{"name":"Gandalf","age":"90","job_occupation":"Wizard","hobbies":["magic","reading"],"favorites_music_genres":["Classical"]}`,
+									},
+								}},
+							},
+						}},
+					}, nil
+				},
+			}
 			var err error
 
 			agent, err := New(
+				WithLLMClient(mock),
 				WithLLMAPIURL(apiURL),
 				WithModel(testModel),
 				WithRandomIdentity("A 90-year-old man with a long beard, a wizard, who lives in a tower."),
Block a user