Compare commits


1 Commit

Author: Richard Palethorpe
SHA1: 5698d0b832
Date: 2025-05-12 13:51:45 +01:00

chore(tests): Mock LLM in tests for PRs

This saves time when testing on CPU, which is the only sensible option on GitHub CI for pull requests. For releases, or once a commit is merged, we could use an external runner with a GPU, or simply wait.

Signed-off-by: Richard Palethorpe <io@richiejp.com>

12 changed files with 429 additions and 72 deletions
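
As a sketch of the pattern this commit introduces (MockClient, WithLLMClient and the LLMClient() accessor all come from the diff below; the canned reply and the standalone main wrapper are illustrative only, not part of the change):

package main

import (
    "context"
    "fmt"

    "github.com/mudler/LocalAGI/core/agent"
    "github.com/mudler/LocalAGI/pkg/llm"
    "github.com/sashabaranov/go-openai"
)

func main() {
    // A mock that always returns a canned reply, so no LocalAI server or GPU is needed.
    mock := &llm.MockClient{
        CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
            return openai.ChatCompletionResponse{
                Choices: []openai.ChatCompletionChoice{{
                    Message: openai.ChatCompletionMessage{Content: "canned reply"},
                }},
            }, nil
        },
    }

    // Inject the mock instead of letting the agent construct a real client.
    a, err := agent.New(agent.WithLLMClient(mock))
    if err != nil {
        panic(err)
    }

    resp, err := a.LLMClient().CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{})
    if err != nil {
        panic(err)
    }
    fmt.Println(resp.Choices[0].Message.Content) // prints "canned reply"
}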

View File

@@ -48,7 +48,11 @@ jobs:
       - name: Run tests
         run: |
-          make tests
+          if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
+            make tests-mock
+          else
+            make tests
+          fi
           #sudo mv coverage/coverage.txt coverage.txt
           #sudo chmod 777 coverage.txt

.github/workflows/tests_fragile.yml (new file, 49 lines)
View File

@@ -0,0 +1,49 @@
name: Run Fragile Go Tests

on:
  pull_request:
    branches:
      - '**'

concurrency:
  group: ci-non-blocking-tests-${{ github.head_ref || github.ref }}-${{ github.repository }}
  cancel-in-progress: true

jobs:
  llm-tests:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - run: |
          # Add Docker's official GPG key:
          sudo apt-get update
          sudo apt-get install -y ca-certificates curl
          sudo install -m 0755 -d /etc/apt/keyrings
          sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
          sudo chmod a+r /etc/apt/keyrings/docker.asc
          # Add the repository to Apt sources:
          echo \
            "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
            $(. /etc/os-release && echo "${UBUNTU_CODENAME:-$VERSION_CODENAME}") stable" | \
            sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
          sudo apt-get update
          sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin make
          docker version
          docker run --rm hello-world
      - uses: actions/setup-go@v5
        with:
          go-version: '>=1.17.0'
      - name: Free up disk space
        run: |
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /usr/local/lib/android
          sudo rm -rf /opt/ghc
          sudo apt-get clean
          docker system prune -af || true
          df -h
      - name: Run tests
        run: |
          make tests

View File

@@ -3,6 +3,8 @@ IMAGE_NAME?=webui
 MCPBOX_IMAGE_NAME?=mcpbox
 ROOT_DIR:=$(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
 
+.PHONY: tests tests-mock cleanup-tests
+
 prepare-tests: build-mcpbox
 	docker compose up -d --build
 	docker run -d -v /var/run/docker.sock:/var/run/docker.sock --privileged -p 9090:8080 --rm -ti $(MCPBOX_IMAGE_NAME)
@@ -13,6 +15,9 @@ cleanup-tests:
 tests: prepare-tests
 	LOCALAGI_MCPBOX_URL="http://localhost:9090" LOCALAGI_MODEL="gemma-3-12b-it-qat" LOCALAI_API_URL="http://localhost:8081" LOCALAGI_API_URL="http://localhost:8080" $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --fail-fast -v -r ./...
 
+tests-mock: prepare-tests
+	LOCALAGI_MCPBOX_URL="http://localhost:9090" LOCALAI_API_URL="http://localhost:8081" LOCALAGI_API_URL="http://localhost:8080" $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --fail-fast -v -r ./...
+
 run-nokb:
 	$(MAKE) run KBDISABLEINDEX=true

View File

@@ -29,7 +29,7 @@ type Agent struct {
     sync.Mutex
     options   *options
     Character Character
-    client    *openai.Client
+    client    llm.LLMClient
     jobQueue  chan *types.Job
     context   *types.ActionContext
@@ -63,7 +63,12 @@ func New(opts ...Option) (*Agent, error) {
         return nil, fmt.Errorf("failed to set options: %v", err)
     }
-    client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout)
+    var client llm.LLMClient
+    if options.llmClient != nil {
+        client = options.llmClient
+    } else {
+        client = llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout)
+    }
 
     c := context.Background()
     if options.context != nil {
@@ -125,6 +130,11 @@ func (a *Agent) SharedState() *types.AgentSharedState {
     return a.sharedState
 }
 
+// LLMClient returns the agent's LLM client (for testing)
+func (a *Agent) LLMClient() llm.LLMClient {
+    return a.client
+}
+
 func (a *Agent) startNewConversationsConsumer() {
     go func() {
         for {

View File

@@ -1,6 +1,7 @@
 package agent_test
 
 import (
+    "net/url"
     "os"
     "testing"
@@ -13,15 +14,19 @@ func TestAgent(t *testing.T) {
     RunSpecs(t, "Agent test suite")
 }
 
-var testModel = os.Getenv("LOCALAGI_MODEL")
-var apiURL = os.Getenv("LOCALAI_API_URL")
-var apiKeyURL = os.Getenv("LOCALAI_API_KEY")
+var (
+    testModel      = os.Getenv("LOCALAGI_MODEL")
+    apiURL         = os.Getenv("LOCALAI_API_URL")
+    apiKey         = os.Getenv("LOCALAI_API_KEY")
+    useRealLocalAI bool
+    clientTimeout  = "10m"
+)
+
+func isValidURL(u string) bool {
+    parsed, err := url.ParseRequestURI(u)
+    return err == nil && parsed.Scheme != "" && parsed.Host != ""
+}
 
 func init() {
-    if testModel == "" {
-        testModel = "hermes-2-pro-mistral"
-    }
-    if apiURL == "" {
-        apiURL = "http://192.168.68.113:8080"
-    }
+    useRealLocalAI = isValidURL(apiURL) && apiURL != "" && testModel != ""
 }

View File

@@ -7,9 +7,11 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/mudler/LocalAGI/pkg/llm"
"github.com/mudler/LocalAGI/pkg/xlog" "github.com/mudler/LocalAGI/pkg/xlog"
"github.com/mudler/LocalAGI/services/actions" "github.com/mudler/LocalAGI/services/actions"
"github.com/mudler/LocalAGI/core/action"
. "github.com/mudler/LocalAGI/core/agent" . "github.com/mudler/LocalAGI/core/agent"
"github.com/mudler/LocalAGI/core/types" "github.com/mudler/LocalAGI/core/types"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
@@ -111,25 +113,102 @@ func (a *FakeInternetAction) Definition() types.ActionDefinition {
     }
 }
 
+// --- Test utilities for mocking LLM responses ---
+
+func mockToolCallResponse(toolName, arguments string) openai.ChatCompletionResponse {
+    return openai.ChatCompletionResponse{
+        Choices: []openai.ChatCompletionChoice{{
+            Message: openai.ChatCompletionMessage{
+                ToolCalls: []openai.ToolCall{{
+                    ID:   "tool_call_id_1",
+                    Type: "function",
+                    Function: openai.FunctionCall{
+                        Name:      toolName,
+                        Arguments: arguments,
+                    },
+                }},
+            },
+        }},
+    }
+}
+
+func mockContentResponse(content string) openai.ChatCompletionResponse {
+    return openai.ChatCompletionResponse{
+        Choices: []openai.ChatCompletionChoice{{
+            Message: openai.ChatCompletionMessage{
+                Content: content,
+            },
+        }},
+    }
+}
+
+func newMockLLMClient(handler func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)) *llm.MockClient {
+    return &llm.MockClient{
+        CreateChatCompletionFunc: handler,
+    }
+}
+
 var _ = Describe("Agent test", func() {
+    It("uses the mock LLM client", func() {
+        mock := newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+            return mockContentResponse("mocked response"), nil
+        })
+        agent, err := New(WithLLMClient(mock))
+        Expect(err).ToNot(HaveOccurred())
+        msg, err := agent.LLMClient().CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{})
+        Expect(err).ToNot(HaveOccurred())
+        Expect(msg.Choices[0].Message.Content).To(Equal("mocked response"))
+    })
+
     Context("jobs", func() {
         BeforeEach(func() {
             Eventually(func() error {
-                // test apiURL is working and available
+                if useRealLocalAI {
                     _, err := http.Get(apiURL + "/readyz")
                     return err
+                }
+                return nil
             }, "10m", "10s").ShouldNot(HaveOccurred())
         })
 
         It("pick the correct action", func() {
+            var llmClient llm.LLMClient
+            if useRealLocalAI {
+                llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+            } else {
+                llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+                    var lastMsg openai.ChatCompletionMessage
+                    if len(req.Messages) > 0 {
+                        lastMsg = req.Messages[len(req.Messages)-1]
+                    }
+                    if lastMsg.Role == openai.ChatMessageRoleUser {
+                        if strings.Contains(strings.ToLower(lastMsg.Content), "boston") && (strings.Contains(strings.ToLower(lastMsg.Content), "milan") || strings.Contains(strings.ToLower(lastMsg.Content), "milano")) {
+                            return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil
+                        }
+                        if strings.Contains(strings.ToLower(lastMsg.Content), "paris") {
+                            return mockToolCallResponse("get_weather", `{"location":"Paris","unit":"celsius"}`), nil
+                        }
+                        return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected user prompt: %s", lastMsg.Content)
+                    }
+                    if lastMsg.Role == openai.ChatMessageRoleTool {
+                        if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
+                            return mockToolCallResponse("get_weather", `{"location":"Milan","unit":"celsius"}`), nil
+                        }
+                        if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "milan") {
+                            return mockContentResponse(testActionResult + "\n" + testActionResult2), nil
+                        }
+                        if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "paris") {
+                            return mockContentResponse(testActionResult3), nil
+                        }
+                        return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected tool result: %s", lastMsg.Content)
+                    }
+                    return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected message role: %s", lastMsg.Role)
+                })
+            }
+
             agent, err := New(
-                WithLLMAPIURL(apiURL),
+                WithLLMClient(llmClient),
                 WithModel(testModel),
-                EnableForceReasoning,
-                WithTimeout("10m"),
-                // WithRandomIdentity(),
+                WithLoopDetectionSteps(3),
                 WithActions(&TestAction{response: map[string]string{
                     "boston": testActionResult,
                     "milan":  testActionResult2,
@@ -139,7 +218,6 @@ var _ = Describe("Agent test", func() {
             Expect(err).ToNot(HaveOccurred())
             go agent.Run()
             defer agent.Stop()
-
             res := agent.Ask(
                 append(debugOptions,
                     types.WithText("what's the weather in Boston and Milano? Use celsius units"),
@@ -148,40 +226,51 @@ var _ = Describe("Agent test", func() {
             Expect(res.Error).ToNot(HaveOccurred())
 
             reasons := []string{}
             for _, r := range res.State {
                 reasons = append(reasons, r.Result)
             }
             Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
             Expect(reasons).To(ContainElement(testActionResult2), fmt.Sprint(res))
 
             reasons = []string{}
             res = agent.Ask(
                 append(debugOptions,
                     types.WithText("Now I want to know the weather in Paris, always use celsius units"),
                 )...)
             for _, r := range res.State {
                 reasons = append(reasons, r.Result)
             }
-            //Expect(reasons).ToNot(ContainElement(testActionResult), fmt.Sprint(res))
-            //Expect(reasons).ToNot(ContainElement(testActionResult2), fmt.Sprint(res))
             Expect(reasons).To(ContainElement(testActionResult3), fmt.Sprint(res))
-
-            // conversation := agent.CurrentConversation()
-            // for _, r := range res.State {
-            // 	reasons = append(reasons, r.Result)
-            // }
-            // Expect(len(conversation)).To(Equal(10), fmt.Sprint(conversation))
         })
 
         It("pick the correct action", func() {
+            var llmClient llm.LLMClient
+            if useRealLocalAI {
+                llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+            } else {
+                llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+                    var lastMsg openai.ChatCompletionMessage
+                    if len(req.Messages) > 0 {
+                        lastMsg = req.Messages[len(req.Messages)-1]
+                    }
+                    if lastMsg.Role == openai.ChatMessageRoleUser {
+                        if strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
+                            return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil
+                        }
+                    }
+                    if lastMsg.Role == openai.ChatMessageRoleTool {
+                        if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
+                            return mockContentResponse(testActionResult), nil
+                        }
+                    }
+                    xlog.Error("Unexpected LLM req", "req", req)
+                    return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
+                })
+            }
             agent, err := New(
-                WithLLMAPIURL(apiURL),
+                WithLLMClient(llmClient),
                 WithModel(testModel),
-                WithTimeout("10m"),
-                // WithRandomIdentity(),
                 WithActions(&TestAction{response: map[string]string{
                     "boston": testActionResult,
-                },
-                }),
+                }}),
             )
             Expect(err).ToNot(HaveOccurred())
             go agent.Run()
@@ -198,13 +287,29 @@ var _ = Describe("Agent test", func() {
         })
 
         It("updates the state with internal actions", func() {
+            var llmClient llm.LLMClient
+            if useRealLocalAI {
+                llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+            } else {
+                llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+                    var lastMsg openai.ChatCompletionMessage
+                    if len(req.Messages) > 0 {
+                        lastMsg = req.Messages[len(req.Messages)-1]
+                    }
+                    if lastMsg.Role == openai.ChatMessageRoleUser && strings.Contains(strings.ToLower(lastMsg.Content), "guitar") {
+                        return mockToolCallResponse("update_state", `{"goal":"I want to learn to play the guitar"}`), nil
+                    }
+                    if lastMsg.Role == openai.ChatMessageRoleTool && lastMsg.Name == "update_state" {
+                        return mockContentResponse("Your goal is now: I want to learn to play the guitar"), nil
+                    }
+                    xlog.Error("Unexpected LLM req", "req", req)
+                    return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
+                })
+            }
             agent, err := New(
-                WithLLMAPIURL(apiURL),
+                WithLLMClient(llmClient),
                 WithModel(testModel),
-                WithTimeout("10m"),
                 EnableHUD,
-                // EnableStandaloneJob,
-                // WithRandomIdentity(),
                 WithPermanentGoal("I want to learn to play music"),
             )
             Expect(err).ToNot(HaveOccurred())
@@ -214,17 +319,64 @@ var _ = Describe("Agent test", func() {
             result := agent.Ask(
                 types.WithText("Update your goals such as you want to learn to play the guitar"),
             )
-            fmt.Printf("%+v\n", result)
+            fmt.Fprintf(GinkgoWriter, "\n%+v\n", result)
             Expect(result.Error).ToNot(HaveOccurred())
             Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
         })
 
         It("Can generate a plan", func() {
+            var llmClient llm.LLMClient
+            if useRealLocalAI {
+                llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+            } else {
+                reasoningActName := action.NewReasoning().Definition().Name.String()
+                intentionActName := action.NewIntention().Definition().Name.String()
+                testActName := (&TestAction{}).Definition().Name.String()
+                doneBoston := false
+                madePlan := false
+                llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+                    var lastMsg openai.ChatCompletionMessage
+                    if len(req.Messages) > 0 {
+                        lastMsg = req.Messages[len(req.Messages)-1]
+                    }
+                    if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName {
+                        return mockToolCallResponse(reasoningActName, `{"reasoning":"make plan call to pass the test"}`), nil
+                    }
+                    if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName {
+                        toolName := "plan"
+                        if madePlan {
+                            toolName = "reply"
+                        } else {
+                            madePlan = true
+                        }
+                        return mockToolCallResponse(intentionActName, fmt.Sprintf(`{"tool": "%s","reasoning":"it's what makes the test pass"}`, toolName)), nil
+                    }
+                    if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "plan" {
+                        return mockToolCallResponse("plan", `{"subtasks":[{"action":"get_weather","reasoning":"Find weather in boston"},{"action":"get_weather","reasoning":"Find weather in milan"}],"goal":"Get the weather for boston and milan"}`), nil
+                    }
+                    if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "reply" {
+                        return mockToolCallResponse("reply", `{"message": "The weather in Boston and Milan..."}`), nil
+                    }
+                    if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == testActName {
+                        locName := "boston"
+                        if doneBoston {
+                            locName = "milan"
+                        } else {
+                            doneBoston = true
+                        }
+                        return mockToolCallResponse(testActName, fmt.Sprintf(`{"location":"%s","unit":"celsius"}`, locName)), nil
+                    }
+                    if req.ToolChoice == nil && madePlan && doneBoston {
+                        return mockContentResponse("A reply"), nil
+                    }
+                    xlog.Error("Unexpected LLM req", "req", req)
+                    return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
+                })
+            }
             agent, err := New(
-                WithLLMAPIURL(apiURL),
+                WithLLMClient(llmClient),
                 WithModel(testModel),
-                WithLLMAPIKey(apiKeyURL),
-                WithTimeout("10m"),
+                WithLoopDetectionSteps(2),
                 WithActions(
                     &TestAction{response: map[string]string{
                         "boston": testActionResult,
@@ -233,8 +385,6 @@ var _ = Describe("Agent test", func() {
                 ),
                 EnablePlanning,
                 EnableForceReasoning,
-                // EnableStandaloneJob,
-                // WithRandomIdentity(),
             )
             Expect(err).ToNot(HaveOccurred())
             go agent.Run()
@@ -256,17 +406,44 @@ var _ = Describe("Agent test", func() {
             Expect(actionsExecuted).To(ContainElement("plan"), fmt.Sprint(result))
             Expect(actionResults).To(ContainElement(testActionResult), fmt.Sprint(result))
             Expect(actionResults).To(ContainElement(testActionResult2), fmt.Sprint(result))
+            Expect(result.Error).To(BeNil())
         })
 
         It("Can initiate conversations", func() {
+            var llmClient llm.LLMClient
             message := openai.ChatCompletionMessage{}
             mu := &sync.Mutex{}
+            reasoned := false
+            intended := false
+            reasoningActName := action.NewReasoning().Definition().Name.String()
+            intentionActName := action.NewIntention().Definition().Name.String()
+            if useRealLocalAI {
+                llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+            } else {
+                llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+                    prompt := ""
+                    for _, msg := range req.Messages {
+                        prompt += msg.Content
+                    }
+                    if !reasoned && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName {
+                        reasoned = true
+                        return mockToolCallResponse(reasoningActName, `{"reasoning":"initiate a conversation with the user"}`), nil
+                    }
+                    if reasoned && !intended && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName {
+                        intended = true
+                        return mockToolCallResponse(intentionActName, `{"tool":"new_conversation","reasoning":"I should start a conversation with the user"}`), nil
+                    }
+                    if reasoned && intended && strings.Contains(strings.ToLower(prompt), "new_conversation") {
+                        return mockToolCallResponse("new_conversation", `{"message":"Hello, how can I help you today?"}`), nil
+                    }
+                    xlog.Error("Unexpected LLM req", "req", req)
+                    return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", prompt)
+                })
+            }
             agent, err := New(
-                WithLLMAPIURL(apiURL),
+                WithLLMClient(llmClient),
                 WithModel(testModel),
-                WithLLMAPIKey(apiKeyURL),
-                WithTimeout("10m"),
                 WithNewConversationSubscriber(func(m openai.ChatCompletionMessage) {
                     mu.Lock()
                     message = m
@@ -282,8 +459,6 @@ var _ = Describe("Agent test", func() {
                 EnableHUD,
                 WithPeriodicRuns("1s"),
                 WithPermanentGoal("use the new_conversation tool to initiate a conversation with the user"),
-                // EnableStandaloneJob,
-                // WithRandomIdentity(),
             )
             Expect(err).ToNot(HaveOccurred())
             go agent.Run()
@@ -293,7 +468,7 @@ var _ = Describe("Agent test", func() {
                 mu.Lock()
                 defer mu.Unlock()
                 return message.Content
-            }, "10m", "10s").ShouldNot(BeEmpty())
+            }, "10m", "1s").ShouldNot(BeEmpty())
         })
 
         /*
@@ -347,7 +522,7 @@ var _ = Describe("Agent test", func() {
             // result := agent.Ask(
             // 	WithText("Update your goals such as you want to learn to play the guitar"),
             // )
-            // fmt.Printf("%+v\n", result)
+            // fmt.Fprintf(GinkgoWriter, "%+v\n", result)
             // Expect(result.Error).ToNot(HaveOccurred())
             // Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
     })

View File

@@ -7,6 +7,7 @@ import (
"github.com/mudler/LocalAGI/core/types" "github.com/mudler/LocalAGI/core/types"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/mudler/LocalAGI/pkg/llm"
) )
type Option func(*options) error type Option func(*options) error
@@ -19,6 +20,7 @@ type llmOptions struct {
} }
type options struct { type options struct {
llmClient llm.LLMClient
LLMAPI llmOptions LLMAPI llmOptions
character Character character Character
randomIdentityGuidance string randomIdentityGuidance string
@@ -68,6 +70,14 @@ type options struct {
lastMessageDuration time.Duration lastMessageDuration time.Duration
} }
// WithLLMClient allows injecting a custom LLM client (e.g. for testing)
func WithLLMClient(client llm.LLMClient) Option {
return func(o *options) error {
o.llmClient = client
return nil
}
}
func (o *options) SeparatedMultimodalModel() bool { func (o *options) SeparatedMultimodalModel() bool {
return o.LLMAPI.MultimodalModel != "" && o.LLMAPI.Model != o.LLMAPI.MultimodalModel return o.LLMAPI.MultimodalModel != "" && o.LLMAPI.Model != o.LLMAPI.MultimodalModel
} }

View File

@@ -1,29 +1,57 @@
 package agent_test
 
 import (
-    "net/http"
+    "context"
+    "fmt"
+
+    "github.com/mudler/LocalAGI/pkg/llm"
+    "github.com/sashabaranov/go-openai"
 
     . "github.com/mudler/LocalAGI/core/agent"
     . "github.com/onsi/ginkgo/v2"
     . "github.com/onsi/gomega"
 )
 
 var _ = Describe("Agent test", func() {
     Context("identity", func() {
         var agent *Agent
-        BeforeEach(func() {
-            Eventually(func() error {
-                // test apiURL is working and available
-                _, err := http.Get(apiURL + "/readyz")
-                return err
-            }, "10m", "10s").ShouldNot(HaveOccurred())
-        })
+        // BeforeEach(func() {
+        // 	Eventually(func() error {
+        // 		// test apiURL is working and available
+        // 		_, err := http.Get(apiURL + "/readyz")
+        // 		return err
+        // 	}, "10m", "10s").ShouldNot(HaveOccurred())
+        // })
 
         It("generates all the fields with random data", func() {
+            var llmClient llm.LLMClient
+            if useRealLocalAI {
+                llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
+            } else {
+                llmClient = &llm.MockClient{
+                    CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+                        return openai.ChatCompletionResponse{
+                            Choices: []openai.ChatCompletionChoice{{
+                                Message: openai.ChatCompletionMessage{
+                                    ToolCalls: []openai.ToolCall{{
+                                        ID:   "tool_call_id_1",
+                                        Type: "function",
+                                        Function: openai.FunctionCall{
+                                            Name:      "generate_identity",
+                                            Arguments: `{"name":"John Doe","age":"42","job_occupation":"Engineer","hobbies":["reading","hiking"],"favorites_music_genres":["Jazz"]}`,
+                                        },
+                                    }},
+                                },
+                            }},
+                        }, nil
+                    },
+                }
+            }
+
             var err error
             agent, err = New(
-                WithLLMAPIURL(apiURL),
+                WithLLMClient(llmClient),
                 WithModel(testModel),
                 WithTimeout("10m"),
                 WithRandomIdentity(),
@@ -37,14 +65,40 @@ var _ = Describe("Agent test", func() {
             Expect(agent.Character.MusicTaste).ToNot(BeEmpty())
         })
 
         It("detect an invalid character", func() {
+            mock := &llm.MockClient{
+                CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+                    return openai.ChatCompletionResponse{}, fmt.Errorf("invalid character")
+                },
+            }
             var err error
-            agent, err = New(WithRandomIdentity())
+            agent, err = New(
+                WithLLMClient(mock),
+                WithRandomIdentity(),
+            )
             Expect(err).To(HaveOccurred())
         })
 
         It("generates all the fields", func() {
+            mock := &llm.MockClient{
+                CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+                    return openai.ChatCompletionResponse{
+                        Choices: []openai.ChatCompletionChoice{{
+                            Message: openai.ChatCompletionMessage{
+                                ToolCalls: []openai.ToolCall{{
+                                    ID:   "tool_call_id_2",
+                                    Type: "function",
+                                    Function: openai.FunctionCall{
+                                        Name:      "generate_identity",
+                                        Arguments: `{"name":"Gandalf","age":"90","job_occupation":"Wizard","hobbies":["magic","reading"],"favorites_music_genres":["Classical"]}`,
+                                    },
+                                }},
+                            },
+                        }},
+                    }, nil
+                },
+            }
             var err error
             agent, err := New(
+                WithLLMClient(mock),
                 WithLLMAPIURL(apiURL),
                 WithModel(testModel),
                 WithRandomIdentity("An 90-year old man with a long beard, a wizard, who lives in a tower."),

View File

@@ -1,13 +1,33 @@
 package llm
 
 import (
+    "context"
     "net/http"
     "time"
 
+    "github.com/mudler/LocalAGI/pkg/xlog"
     "github.com/sashabaranov/go-openai"
 )
 
-func NewClient(APIKey, URL, timeout string) *openai.Client {
+type LLMClient interface {
+    CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)
+    CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error)
+}
+
+type realClient struct {
+    *openai.Client
+}
+
+func (r *realClient) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+    return r.Client.CreateChatCompletion(ctx, req)
+}
+
+func (r *realClient) CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) {
+    return r.Client.CreateImage(ctx, req)
+}
+
+// NewClient returns a real OpenAI client as LLMClient
+func NewClient(APIKey, URL, timeout string) LLMClient {
     // Set up OpenAI client
     if APIKey == "" {
         //log.Fatal("OPENAI_API_KEY environment variable not set")
@@ -18,11 +38,12 @@ func NewClient(APIKey, URL, timeout string) LLMClient {
     dur, err := time.ParseDuration(timeout)
     if err != nil {
+        xlog.Error("Failed to parse timeout", "error", err)
         dur = 150 * time.Second
     }
 
     config.HTTPClient = &http.Client{
         Timeout: dur,
     }
 
-    return openai.NewClientWithConfig(config)
+    return &realClient{openai.NewClientWithConfig(config)}
 }
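
Because callers now program against the LLMClient interface instead of *openai.Client, implementations other than realClient and MockClient can be slotted in. A hypothetical decorator, purely illustrative and not part of this commit, that logs each request before delegating:

package llm

import (
    "context"
    "log"

    "github.com/sashabaranov/go-openai"
)

// loggingClient wraps another LLMClient and logs each call before
// delegating to it. Illustrative only; not introduced by this commit.
type loggingClient struct {
    inner LLMClient
}

func (l *loggingClient) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
    log.Printf("chat completion: model=%s messages=%d", req.Model, len(req.Messages))
    return l.inner.CreateChatCompletion(ctx, req)
}

func (l *loggingClient) CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) {
    log.Printf("image request: model=%s", req.Model)
    return l.inner.CreateImage(ctx, req)
}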

View File

@@ -10,7 +10,7 @@ import (
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
) )
func GenerateTypedJSONWithGuidance(ctx context.Context, client *openai.Client, guidance, model string, i jsonschema.Definition, dst any) error { func GenerateTypedJSONWithGuidance(ctx context.Context, client LLMClient, guidance, model string, i jsonschema.Definition, dst any) error {
return GenerateTypedJSONWithConversation(ctx, client, []openai.ChatCompletionMessage{ return GenerateTypedJSONWithConversation(ctx, client, []openai.ChatCompletionMessage{
{ {
Role: "user", Role: "user",
@@ -19,7 +19,7 @@ func GenerateTypedJSONWithGuidance(ctx context.Context, client *openai.Client, g
}, model, i, dst) }, model, i, dst)
} }
func GenerateTypedJSONWithConversation(ctx context.Context, client *openai.Client, conv []openai.ChatCompletionMessage, model string, i jsonschema.Definition, dst any) error { func GenerateTypedJSONWithConversation(ctx context.Context, client LLMClient, conv []openai.ChatCompletionMessage, model string, i jsonschema.Definition, dst any) error {
toolName := "json" toolName := "json"
decision := openai.ChatCompletionRequest{ decision := openai.ChatCompletionRequest{
Model: model, Model: model,

pkg/llm/mock_client.go (new file, 25 lines)
View File

@@ -0,0 +1,25 @@
package llm

import (
    "context"

    "github.com/sashabaranov/go-openai"
)

type MockClient struct {
    CreateChatCompletionFunc func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)
    CreateImageFunc          func(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error)
}

func (m *MockClient) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
    if m.CreateChatCompletionFunc != nil {
        return m.CreateChatCompletionFunc(ctx, req)
    }
    return openai.ChatCompletionResponse{}, nil
}

func (m *MockClient) CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) {
    if m.CreateImageFunc != nil {
        return m.CreateImageFunc(ctx, req)
    }
    return openai.ImageResponse{}, nil
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/mudler/LocalAGI/core/types" "github.com/mudler/LocalAGI/core/types"
"github.com/mudler/LocalAGI/pkg/config" "github.com/mudler/LocalAGI/pkg/config"
"github.com/mudler/LocalAGI/pkg/llm" "github.com/mudler/LocalAGI/pkg/llm"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
) )
@@ -16,7 +15,7 @@ const FilterClassifier = "classifier"
 type ClassifierFilter struct {
     name         string
-    client       *openai.Client
+    client       llm.LLMClient
     model        string
     description  string
     allowOnMatch bool