chore(tests): Mock LLM in tests for PRs
This saves time when testing on CPU which is the only sensible thing to do on GitHub CI for PRs. For releases or once the commit is merged we could use an external runner with GPU or just wait. Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -48,7 +48,11 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
|
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
|
||||||
|
make tests-mock
|
||||||
|
else
|
||||||
make tests
|
make tests
|
||||||
|
fi
|
||||||
#sudo mv coverage/coverage.txt coverage.txt
|
#sudo mv coverage/coverage.txt coverage.txt
|
||||||
#sudo chmod 777 coverage.txt
|
#sudo chmod 777 coverage.txt
|
||||||
|
|
||||||
|
|||||||
49
.github/workflows/tests_fragile.yml
vendored
Normal file
49
.github/workflows/tests_fragile.yml
vendored
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
name: Run Fragile Go Tests
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- '**'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ci-non-blocking-tests-${{ github.head_ref || github.ref }}-${{ github.repository }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
llm-tests:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- run: |
|
||||||
|
# Add Docker's official GPG key:
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y ca-certificates curl
|
||||||
|
sudo install -m 0755 -d /etc/apt/keyrings
|
||||||
|
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
|
||||||
|
sudo chmod a+r /etc/apt/keyrings/docker.asc
|
||||||
|
|
||||||
|
# Add the repository to Apt sources:
|
||||||
|
echo \
|
||||||
|
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
|
||||||
|
$(. /etc/os-release && echo "${UBUNTU_CODENAME:-$VERSION_CODENAME}") stable" | \
|
||||||
|
sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin make
|
||||||
|
docker version
|
||||||
|
|
||||||
|
docker run --rm hello-world
|
||||||
|
- uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '>=1.17.0'
|
||||||
|
- name: Free up disk space
|
||||||
|
run: |
|
||||||
|
sudo rm -rf /usr/share/dotnet
|
||||||
|
sudo rm -rf /usr/local/lib/android
|
||||||
|
sudo rm -rf /opt/ghc
|
||||||
|
sudo apt-get clean
|
||||||
|
docker system prune -af || true
|
||||||
|
df -h
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
make tests
|
||||||
5
Makefile
5
Makefile
@@ -3,6 +3,8 @@ IMAGE_NAME?=webui
|
|||||||
MCPBOX_IMAGE_NAME?=mcpbox
|
MCPBOX_IMAGE_NAME?=mcpbox
|
||||||
ROOT_DIR:=$(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
|
ROOT_DIR:=$(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
|
||||||
|
|
||||||
|
.PHONY: tests tests-mock cleanup-tests
|
||||||
|
|
||||||
prepare-tests: build-mcpbox
|
prepare-tests: build-mcpbox
|
||||||
docker compose up -d --build
|
docker compose up -d --build
|
||||||
docker run -d -v /var/run/docker.sock:/var/run/docker.sock --privileged -p 9090:8080 --rm -ti $(MCPBOX_IMAGE_NAME)
|
docker run -d -v /var/run/docker.sock:/var/run/docker.sock --privileged -p 9090:8080 --rm -ti $(MCPBOX_IMAGE_NAME)
|
||||||
@@ -13,6 +15,9 @@ cleanup-tests:
|
|||||||
tests: prepare-tests
|
tests: prepare-tests
|
||||||
LOCALAGI_MCPBOX_URL="http://localhost:9090" LOCALAGI_MODEL="gemma-3-12b-it-qat" LOCALAI_API_URL="http://localhost:8081" LOCALAGI_API_URL="http://localhost:8080" $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --fail-fast -v -r ./...
|
LOCALAGI_MCPBOX_URL="http://localhost:9090" LOCALAGI_MODEL="gemma-3-12b-it-qat" LOCALAI_API_URL="http://localhost:8081" LOCALAGI_API_URL="http://localhost:8080" $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --fail-fast -v -r ./...
|
||||||
|
|
||||||
|
tests-mock: prepare-tests
|
||||||
|
LOCALAGI_MCPBOX_URL="http://localhost:9090" LOCALAI_API_URL="http://localhost:8081" LOCALAGI_API_URL="http://localhost:8080" $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --fail-fast -v -r ./...
|
||||||
|
|
||||||
run-nokb:
|
run-nokb:
|
||||||
$(MAKE) run KBDISABLEINDEX=true
|
$(MAKE) run KBDISABLEINDEX=true
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ type Agent struct {
|
|||||||
sync.Mutex
|
sync.Mutex
|
||||||
options *options
|
options *options
|
||||||
Character Character
|
Character Character
|
||||||
client *openai.Client
|
client llm.LLMClient
|
||||||
jobQueue chan *types.Job
|
jobQueue chan *types.Job
|
||||||
context *types.ActionContext
|
context *types.ActionContext
|
||||||
|
|
||||||
@@ -63,7 +63,12 @@ func New(opts ...Option) (*Agent, error) {
|
|||||||
return nil, fmt.Errorf("failed to set options: %v", err)
|
return nil, fmt.Errorf("failed to set options: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout)
|
var client llm.LLMClient
|
||||||
|
if options.llmClient != nil {
|
||||||
|
client = options.llmClient
|
||||||
|
} else {
|
||||||
|
client = llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout)
|
||||||
|
}
|
||||||
|
|
||||||
c := context.Background()
|
c := context.Background()
|
||||||
if options.context != nil {
|
if options.context != nil {
|
||||||
@@ -125,6 +130,11 @@ func (a *Agent) SharedState() *types.AgentSharedState {
|
|||||||
return a.sharedState
|
return a.sharedState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LLMClient returns the agent's LLM client (for testing)
|
||||||
|
func (a *Agent) LLMClient() llm.LLMClient {
|
||||||
|
return a.client
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Agent) startNewConversationsConsumer() {
|
func (a *Agent) startNewConversationsConsumer() {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package agent_test
|
package agent_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -13,15 +14,19 @@ func TestAgent(t *testing.T) {
|
|||||||
RunSpecs(t, "Agent test suite")
|
RunSpecs(t, "Agent test suite")
|
||||||
}
|
}
|
||||||
|
|
||||||
var testModel = os.Getenv("LOCALAGI_MODEL")
|
var (
|
||||||
var apiURL = os.Getenv("LOCALAI_API_URL")
|
testModel = os.Getenv("LOCALAGI_MODEL")
|
||||||
var apiKeyURL = os.Getenv("LOCALAI_API_KEY")
|
apiURL = os.Getenv("LOCALAI_API_URL")
|
||||||
|
apiKey = os.Getenv("LOCALAI_API_KEY")
|
||||||
|
useRealLocalAI bool
|
||||||
|
clientTimeout = "10m"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isValidURL(u string) bool {
|
||||||
|
parsed, err := url.ParseRequestURI(u)
|
||||||
|
return err == nil && parsed.Scheme != "" && parsed.Host != ""
|
||||||
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if testModel == "" {
|
useRealLocalAI = isValidURL(apiURL) && apiURL != "" && testModel != ""
|
||||||
testModel = "hermes-2-pro-mistral"
|
|
||||||
}
|
|
||||||
if apiURL == "" {
|
|
||||||
apiURL = "http://192.168.68.113:8080"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,9 +7,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAGI/pkg/llm"
|
||||||
"github.com/mudler/LocalAGI/pkg/xlog"
|
"github.com/mudler/LocalAGI/pkg/xlog"
|
||||||
"github.com/mudler/LocalAGI/services/actions"
|
"github.com/mudler/LocalAGI/services/actions"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAGI/core/action"
|
||||||
. "github.com/mudler/LocalAGI/core/agent"
|
. "github.com/mudler/LocalAGI/core/agent"
|
||||||
"github.com/mudler/LocalAGI/core/types"
|
"github.com/mudler/LocalAGI/core/types"
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
@@ -111,25 +113,102 @@ func (a *FakeInternetAction) Definition() types.ActionDefinition {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Test utilities for mocking LLM responses ---
|
||||||
|
|
||||||
|
func mockToolCallResponse(toolName, arguments string) openai.ChatCompletionResponse {
|
||||||
|
return openai.ChatCompletionResponse{
|
||||||
|
Choices: []openai.ChatCompletionChoice{{
|
||||||
|
Message: openai.ChatCompletionMessage{
|
||||||
|
ToolCalls: []openai.ToolCall{{
|
||||||
|
ID: "tool_call_id_1",
|
||||||
|
Type: "function",
|
||||||
|
Function: openai.FunctionCall{
|
||||||
|
Name: toolName,
|
||||||
|
Arguments: arguments,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mockContentResponse(content string) openai.ChatCompletionResponse {
|
||||||
|
return openai.ChatCompletionResponse{
|
||||||
|
Choices: []openai.ChatCompletionChoice{{
|
||||||
|
Message: openai.ChatCompletionMessage{
|
||||||
|
Content: content,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockLLMClient(handler func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)) *llm.MockClient {
|
||||||
|
return &llm.MockClient{
|
||||||
|
CreateChatCompletionFunc: handler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var _ = Describe("Agent test", func() {
|
var _ = Describe("Agent test", func() {
|
||||||
|
It("uses the mock LLM client", func() {
|
||||||
|
mock := newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||||
|
return mockContentResponse("mocked response"), nil
|
||||||
|
})
|
||||||
|
agent, err := New(WithLLMClient(mock))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
msg, err := agent.LLMClient().CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(msg.Choices[0].Message.Content).To(Equal("mocked response"))
|
||||||
|
})
|
||||||
|
|
||||||
Context("jobs", func() {
|
Context("jobs", func() {
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
Eventually(func() error {
|
Eventually(func() error {
|
||||||
// test apiURL is working and available
|
if useRealLocalAI {
|
||||||
_, err := http.Get(apiURL + "/readyz")
|
_, err := http.Get(apiURL + "/readyz")
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}, "10m", "10s").ShouldNot(HaveOccurred())
|
}, "10m", "10s").ShouldNot(HaveOccurred())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("pick the correct action", func() {
|
It("pick the correct action", func() {
|
||||||
|
var llmClient llm.LLMClient
|
||||||
|
if useRealLocalAI {
|
||||||
|
llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
|
||||||
|
} else {
|
||||||
|
llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||||
|
var lastMsg openai.ChatCompletionMessage
|
||||||
|
if len(req.Messages) > 0 {
|
||||||
|
lastMsg = req.Messages[len(req.Messages)-1]
|
||||||
|
}
|
||||||
|
if lastMsg.Role == openai.ChatMessageRoleUser {
|
||||||
|
if strings.Contains(strings.ToLower(lastMsg.Content), "boston") && (strings.Contains(strings.ToLower(lastMsg.Content), "milan") || strings.Contains(strings.ToLower(lastMsg.Content), "milano")) {
|
||||||
|
return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil
|
||||||
|
}
|
||||||
|
if strings.Contains(strings.ToLower(lastMsg.Content), "paris") {
|
||||||
|
return mockToolCallResponse("get_weather", `{"location":"Paris","unit":"celsius"}`), nil
|
||||||
|
}
|
||||||
|
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected user prompt: %s", lastMsg.Content)
|
||||||
|
}
|
||||||
|
if lastMsg.Role == openai.ChatMessageRoleTool {
|
||||||
|
if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
|
||||||
|
return mockToolCallResponse("get_weather", `{"location":"Milan","unit":"celsius"}`), nil
|
||||||
|
}
|
||||||
|
if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "milan") {
|
||||||
|
return mockContentResponse(testActionResult + "\n" + testActionResult2), nil
|
||||||
|
}
|
||||||
|
if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "paris") {
|
||||||
|
return mockContentResponse(testActionResult3), nil
|
||||||
|
}
|
||||||
|
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected tool result: %s", lastMsg.Content)
|
||||||
|
}
|
||||||
|
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected message role: %s", lastMsg.Role)
|
||||||
|
})
|
||||||
|
}
|
||||||
agent, err := New(
|
agent, err := New(
|
||||||
WithLLMAPIURL(apiURL),
|
WithLLMClient(llmClient),
|
||||||
WithModel(testModel),
|
WithModel(testModel),
|
||||||
EnableForceReasoning,
|
|
||||||
WithTimeout("10m"),
|
|
||||||
WithLoopDetectionSteps(3),
|
|
||||||
// WithRandomIdentity(),
|
|
||||||
WithActions(&TestAction{response: map[string]string{
|
WithActions(&TestAction{response: map[string]string{
|
||||||
"boston": testActionResult,
|
"boston": testActionResult,
|
||||||
"milan": testActionResult2,
|
"milan": testActionResult2,
|
||||||
@@ -139,7 +218,6 @@ var _ = Describe("Agent test", func() {
|
|||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go agent.Run()
|
go agent.Run()
|
||||||
defer agent.Stop()
|
defer agent.Stop()
|
||||||
|
|
||||||
res := agent.Ask(
|
res := agent.Ask(
|
||||||
append(debugOptions,
|
append(debugOptions,
|
||||||
types.WithText("what's the weather in Boston and Milano? Use celsius units"),
|
types.WithText("what's the weather in Boston and Milano? Use celsius units"),
|
||||||
@@ -148,40 +226,51 @@ var _ = Describe("Agent test", func() {
|
|||||||
Expect(res.Error).ToNot(HaveOccurred())
|
Expect(res.Error).ToNot(HaveOccurred())
|
||||||
reasons := []string{}
|
reasons := []string{}
|
||||||
for _, r := range res.State {
|
for _, r := range res.State {
|
||||||
|
|
||||||
reasons = append(reasons, r.Result)
|
reasons = append(reasons, r.Result)
|
||||||
}
|
}
|
||||||
Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
|
Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
|
||||||
Expect(reasons).To(ContainElement(testActionResult2), fmt.Sprint(res))
|
Expect(reasons).To(ContainElement(testActionResult2), fmt.Sprint(res))
|
||||||
reasons = []string{}
|
reasons = []string{}
|
||||||
|
|
||||||
res = agent.Ask(
|
res = agent.Ask(
|
||||||
append(debugOptions,
|
append(debugOptions,
|
||||||
types.WithText("Now I want to know the weather in Paris, always use celsius units"),
|
types.WithText("Now I want to know the weather in Paris, always use celsius units"),
|
||||||
)...)
|
)...)
|
||||||
for _, r := range res.State {
|
for _, r := range res.State {
|
||||||
|
|
||||||
reasons = append(reasons, r.Result)
|
reasons = append(reasons, r.Result)
|
||||||
}
|
}
|
||||||
//Expect(reasons).ToNot(ContainElement(testActionResult), fmt.Sprint(res))
|
|
||||||
//Expect(reasons).ToNot(ContainElement(testActionResult2), fmt.Sprint(res))
|
|
||||||
Expect(reasons).To(ContainElement(testActionResult3), fmt.Sprint(res))
|
Expect(reasons).To(ContainElement(testActionResult3), fmt.Sprint(res))
|
||||||
// conversation := agent.CurrentConversation()
|
|
||||||
// for _, r := range res.State {
|
|
||||||
// reasons = append(reasons, r.Result)
|
|
||||||
// }
|
|
||||||
// Expect(len(conversation)).To(Equal(10), fmt.Sprint(conversation))
|
|
||||||
})
|
})
|
||||||
|
|
||||||
It("pick the correct action", func() {
|
It("pick the correct action", func() {
|
||||||
|
var llmClient llm.LLMClient
|
||||||
|
if useRealLocalAI {
|
||||||
|
llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
|
||||||
|
} else {
|
||||||
|
llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||||
|
var lastMsg openai.ChatCompletionMessage
|
||||||
|
if len(req.Messages) > 0 {
|
||||||
|
lastMsg = req.Messages[len(req.Messages)-1]
|
||||||
|
}
|
||||||
|
if lastMsg.Role == openai.ChatMessageRoleUser {
|
||||||
|
if strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
|
||||||
|
return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastMsg.Role == openai.ChatMessageRoleTool {
|
||||||
|
if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") {
|
||||||
|
return mockContentResponse(testActionResult), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
xlog.Error("Unexpected LLM req", "req", req)
|
||||||
|
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
|
||||||
|
})
|
||||||
|
}
|
||||||
agent, err := New(
|
agent, err := New(
|
||||||
WithLLMAPIURL(apiURL),
|
WithLLMClient(llmClient),
|
||||||
WithModel(testModel),
|
WithModel(testModel),
|
||||||
WithTimeout("10m"),
|
|
||||||
// WithRandomIdentity(),
|
|
||||||
WithActions(&TestAction{response: map[string]string{
|
WithActions(&TestAction{response: map[string]string{
|
||||||
"boston": testActionResult,
|
"boston": testActionResult,
|
||||||
},
|
}}),
|
||||||
}),
|
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go agent.Run()
|
go agent.Run()
|
||||||
@@ -198,13 +287,29 @@ var _ = Describe("Agent test", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("updates the state with internal actions", func() {
|
It("updates the state with internal actions", func() {
|
||||||
|
var llmClient llm.LLMClient
|
||||||
|
if useRealLocalAI {
|
||||||
|
llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
|
||||||
|
} else {
|
||||||
|
llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||||
|
var lastMsg openai.ChatCompletionMessage
|
||||||
|
if len(req.Messages) > 0 {
|
||||||
|
lastMsg = req.Messages[len(req.Messages)-1]
|
||||||
|
}
|
||||||
|
if lastMsg.Role == openai.ChatMessageRoleUser && strings.Contains(strings.ToLower(lastMsg.Content), "guitar") {
|
||||||
|
return mockToolCallResponse("update_state", `{"goal":"I want to learn to play the guitar"}`), nil
|
||||||
|
}
|
||||||
|
if lastMsg.Role == openai.ChatMessageRoleTool && lastMsg.Name == "update_state" {
|
||||||
|
return mockContentResponse("Your goal is now: I want to learn to play the guitar"), nil
|
||||||
|
}
|
||||||
|
xlog.Error("Unexpected LLM req", "req", req)
|
||||||
|
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
|
||||||
|
})
|
||||||
|
}
|
||||||
agent, err := New(
|
agent, err := New(
|
||||||
WithLLMAPIURL(apiURL),
|
WithLLMClient(llmClient),
|
||||||
WithModel(testModel),
|
WithModel(testModel),
|
||||||
WithTimeout("10m"),
|
|
||||||
EnableHUD,
|
EnableHUD,
|
||||||
// EnableStandaloneJob,
|
|
||||||
// WithRandomIdentity(),
|
|
||||||
WithPermanentGoal("I want to learn to play music"),
|
WithPermanentGoal("I want to learn to play music"),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@@ -214,17 +319,64 @@ var _ = Describe("Agent test", func() {
|
|||||||
result := agent.Ask(
|
result := agent.Ask(
|
||||||
types.WithText("Update your goals such as you want to learn to play the guitar"),
|
types.WithText("Update your goals such as you want to learn to play the guitar"),
|
||||||
)
|
)
|
||||||
fmt.Printf("%+v\n", result)
|
fmt.Fprintf(GinkgoWriter, "\n%+v\n", result)
|
||||||
Expect(result.Error).ToNot(HaveOccurred())
|
Expect(result.Error).ToNot(HaveOccurred())
|
||||||
Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
|
Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("Can generate a plan", func() {
|
It("Can generate a plan", func() {
|
||||||
|
var llmClient llm.LLMClient
|
||||||
|
if useRealLocalAI {
|
||||||
|
llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
|
||||||
|
} else {
|
||||||
|
reasoningActName := action.NewReasoning().Definition().Name.String()
|
||||||
|
intentionActName := action.NewIntention().Definition().Name.String()
|
||||||
|
testActName := (&TestAction{}).Definition().Name.String()
|
||||||
|
doneBoston := false
|
||||||
|
madePlan := false
|
||||||
|
llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||||
|
var lastMsg openai.ChatCompletionMessage
|
||||||
|
if len(req.Messages) > 0 {
|
||||||
|
lastMsg = req.Messages[len(req.Messages)-1]
|
||||||
|
}
|
||||||
|
if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName {
|
||||||
|
return mockToolCallResponse(reasoningActName, `{"reasoning":"make plan call to pass the test"}`), nil
|
||||||
|
}
|
||||||
|
if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName {
|
||||||
|
toolName := "plan"
|
||||||
|
if madePlan {
|
||||||
|
toolName = "reply"
|
||||||
|
} else {
|
||||||
|
madePlan = true
|
||||||
|
}
|
||||||
|
return mockToolCallResponse(intentionActName, fmt.Sprintf(`{"tool": "%s","reasoning":"it's waht makes the test pass"}`, toolName)), nil
|
||||||
|
}
|
||||||
|
if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "plan" {
|
||||||
|
return mockToolCallResponse("plan", `{"subtasks":[{"action":"get_weather","reasoning":"Find weather in boston"},{"action":"get_weather","reasoning":"Find weather in milan"}],"goal":"Get the weather for boston and milan"}`), nil
|
||||||
|
}
|
||||||
|
if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "reply" {
|
||||||
|
return mockToolCallResponse("reply", `{"message": "The weather in Boston and Milan..."}`), nil
|
||||||
|
}
|
||||||
|
if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == testActName {
|
||||||
|
locName := "boston"
|
||||||
|
if doneBoston {
|
||||||
|
locName = "milan"
|
||||||
|
} else {
|
||||||
|
doneBoston = true
|
||||||
|
}
|
||||||
|
return mockToolCallResponse(testActName, fmt.Sprintf(`{"location":"%s","unit":"celsius"}`, locName)), nil
|
||||||
|
}
|
||||||
|
if req.ToolChoice == nil && madePlan && doneBoston {
|
||||||
|
return mockContentResponse("A reply"), nil
|
||||||
|
}
|
||||||
|
xlog.Error("Unexpected LLM req", "req", req)
|
||||||
|
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content)
|
||||||
|
})
|
||||||
|
}
|
||||||
agent, err := New(
|
agent, err := New(
|
||||||
WithLLMAPIURL(apiURL),
|
WithLLMClient(llmClient),
|
||||||
WithModel(testModel),
|
WithModel(testModel),
|
||||||
WithLLMAPIKey(apiKeyURL),
|
WithLoopDetectionSteps(2),
|
||||||
WithTimeout("10m"),
|
|
||||||
WithActions(
|
WithActions(
|
||||||
&TestAction{response: map[string]string{
|
&TestAction{response: map[string]string{
|
||||||
"boston": testActionResult,
|
"boston": testActionResult,
|
||||||
@@ -233,8 +385,6 @@ var _ = Describe("Agent test", func() {
|
|||||||
),
|
),
|
||||||
EnablePlanning,
|
EnablePlanning,
|
||||||
EnableForceReasoning,
|
EnableForceReasoning,
|
||||||
// EnableStandaloneJob,
|
|
||||||
// WithRandomIdentity(),
|
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go agent.Run()
|
go agent.Run()
|
||||||
@@ -256,17 +406,44 @@ var _ = Describe("Agent test", func() {
|
|||||||
Expect(actionsExecuted).To(ContainElement("plan"), fmt.Sprint(result))
|
Expect(actionsExecuted).To(ContainElement("plan"), fmt.Sprint(result))
|
||||||
Expect(actionResults).To(ContainElement(testActionResult), fmt.Sprint(result))
|
Expect(actionResults).To(ContainElement(testActionResult), fmt.Sprint(result))
|
||||||
Expect(actionResults).To(ContainElement(testActionResult2), fmt.Sprint(result))
|
Expect(actionResults).To(ContainElement(testActionResult2), fmt.Sprint(result))
|
||||||
|
Expect(result.Error).To(BeNil())
|
||||||
})
|
})
|
||||||
|
|
||||||
It("Can initiate conversations", func() {
|
It("Can initiate conversations", func() {
|
||||||
|
var llmClient llm.LLMClient
|
||||||
message := openai.ChatCompletionMessage{}
|
message := openai.ChatCompletionMessage{}
|
||||||
mu := &sync.Mutex{}
|
mu := &sync.Mutex{}
|
||||||
|
reasoned := false
|
||||||
|
intended := false
|
||||||
|
reasoningActName := action.NewReasoning().Definition().Name.String()
|
||||||
|
intentionActName := action.NewIntention().Definition().Name.String()
|
||||||
|
|
||||||
|
if useRealLocalAI {
|
||||||
|
llmClient = llm.NewClient(apiKey, apiURL, clientTimeout)
|
||||||
|
} else {
|
||||||
|
llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||||
|
prompt := ""
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
prompt += msg.Content
|
||||||
|
}
|
||||||
|
if !reasoned && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName {
|
||||||
|
reasoned = true
|
||||||
|
return mockToolCallResponse(reasoningActName, `{"reasoning":"initiate a conversation with the user"}`), nil
|
||||||
|
}
|
||||||
|
if reasoned && !intended && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName {
|
||||||
|
intended = true
|
||||||
|
return mockToolCallResponse(intentionActName, `{"tool":"new_conversation","reasoning":"I should start a conversation with the user"}`), nil
|
||||||
|
}
|
||||||
|
if reasoned && intended && strings.Contains(strings.ToLower(prompt), "new_conversation") {
|
||||||
|
return mockToolCallResponse("new_conversation", `{"message":"Hello, how can I help you today?"}`), nil
|
||||||
|
}
|
||||||
|
xlog.Error("Unexpected LLM req", "req", req)
|
||||||
|
return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", prompt)
|
||||||
|
})
|
||||||
|
}
|
||||||
agent, err := New(
|
agent, err := New(
|
||||||
WithLLMAPIURL(apiURL),
|
WithLLMClient(llmClient),
|
||||||
WithModel(testModel),
|
WithModel(testModel),
|
||||||
WithLLMAPIKey(apiKeyURL),
|
|
||||||
WithTimeout("10m"),
|
|
||||||
WithNewConversationSubscriber(func(m openai.ChatCompletionMessage) {
|
WithNewConversationSubscriber(func(m openai.ChatCompletionMessage) {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
message = m
|
message = m
|
||||||
@@ -282,8 +459,6 @@ var _ = Describe("Agent test", func() {
|
|||||||
EnableHUD,
|
EnableHUD,
|
||||||
WithPeriodicRuns("1s"),
|
WithPeriodicRuns("1s"),
|
||||||
WithPermanentGoal("use the new_conversation tool to initiate a conversation with the user"),
|
WithPermanentGoal("use the new_conversation tool to initiate a conversation with the user"),
|
||||||
// EnableStandaloneJob,
|
|
||||||
// WithRandomIdentity(),
|
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go agent.Run()
|
go agent.Run()
|
||||||
@@ -293,7 +468,7 @@ var _ = Describe("Agent test", func() {
|
|||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
return message.Content
|
return message.Content
|
||||||
}, "10m", "10s").ShouldNot(BeEmpty())
|
}, "10m", "1s").ShouldNot(BeEmpty())
|
||||||
})
|
})
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -347,7 +522,7 @@ var _ = Describe("Agent test", func() {
|
|||||||
// result := agent.Ask(
|
// result := agent.Ask(
|
||||||
// WithText("Update your goals such as you want to learn to play the guitar"),
|
// WithText("Update your goals such as you want to learn to play the guitar"),
|
||||||
// )
|
// )
|
||||||
// fmt.Printf("%+v\n", result)
|
// fmt.Fprintf(GinkgoWriter, "%+v\n", result)
|
||||||
// Expect(result.Error).ToNot(HaveOccurred())
|
// Expect(result.Error).ToNot(HaveOccurred())
|
||||||
// Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
|
// Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mudler/LocalAGI/core/types"
|
"github.com/mudler/LocalAGI/core/types"
|
||||||
"github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/mudler/LocalAGI/pkg/llm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Option func(*options) error
|
type Option func(*options) error
|
||||||
@@ -19,6 +20,7 @@ type llmOptions struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type options struct {
|
type options struct {
|
||||||
|
llmClient llm.LLMClient
|
||||||
LLMAPI llmOptions
|
LLMAPI llmOptions
|
||||||
character Character
|
character Character
|
||||||
randomIdentityGuidance string
|
randomIdentityGuidance string
|
||||||
@@ -68,6 +70,14 @@ type options struct {
|
|||||||
lastMessageDuration time.Duration
|
lastMessageDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithLLMClient allows injecting a custom LLM client (e.g. for testing)
|
||||||
|
func WithLLMClient(client llm.LLMClient) Option {
|
||||||
|
return func(o *options) error {
|
||||||
|
o.llmClient = client
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (o *options) SeparatedMultimodalModel() bool {
|
func (o *options) SeparatedMultimodalModel() bool {
|
||||||
return o.LLMAPI.MultimodalModel != "" && o.LLMAPI.Model != o.LLMAPI.MultimodalModel
|
return o.LLMAPI.MultimodalModel != "" && o.LLMAPI.Model != o.LLMAPI.MultimodalModel
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,29 +1,57 @@
|
|||||||
package agent_test
|
package agent_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAGI/pkg/llm"
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
|
||||||
. "github.com/mudler/LocalAGI/core/agent"
|
. "github.com/mudler/LocalAGI/core/agent"
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Agent test", func() {
|
var _ = Describe("Agent test", func() {
|
||||||
Context("identity", func() {
|
Context("identity", func() {
|
||||||
var agent *Agent
|
var agent *Agent
|
||||||
|
|
||||||
BeforeEach(func() {
|
// BeforeEach(func() {
|
||||||
Eventually(func() error {
|
// Eventually(func() error {
|
||||||
// test apiURL is working and available
|
// // test apiURL is working and available
|
||||||
_, err := http.Get(apiURL + "/readyz")
|
// _, err := http.Get(apiURL + "/readyz")
|
||||||
return err
|
// return err
|
||||||
}, "10m", "10s").ShouldNot(HaveOccurred())
|
// }, "10m", "10s").ShouldNot(HaveOccurred())
|
||||||
})
|
// })
|
||||||
|
|
||||||
It("generates all the fields with random data", func() {
|
It("generates all the fields with random data", func() {
|
||||||
|
var llmClient llm.LLMClient
|
||||||
|
if useRealLocalAI {
|
||||||
|
llmClient = llm.NewClient(apiKey, apiURL, testModel)
|
||||||
|
} else {
|
||||||
|
llmClient = &llm.MockClient{
|
||||||
|
CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||||
|
return openai.ChatCompletionResponse{
|
||||||
|
Choices: []openai.ChatCompletionChoice{{
|
||||||
|
Message: openai.ChatCompletionMessage{
|
||||||
|
ToolCalls: []openai.ToolCall{{
|
||||||
|
ID: "tool_call_id_1",
|
||||||
|
Type: "function",
|
||||||
|
Function: openai.FunctionCall{
|
||||||
|
Name: "generate_identity",
|
||||||
|
Arguments: `{"name":"John Doe","age":"42","job_occupation":"Engineer","hobbies":["reading","hiking"],"favorites_music_genres":["Jazz"]}`,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
var err error
|
var err error
|
||||||
agent, err = New(
|
agent, err = New(
|
||||||
WithLLMAPIURL(apiURL),
|
WithLLMClient(llmClient),
|
||||||
WithModel(testModel),
|
WithModel(testModel),
|
||||||
WithTimeout("10m"),
|
WithTimeout("10m"),
|
||||||
WithRandomIdentity(),
|
WithRandomIdentity(),
|
||||||
@@ -37,14 +65,40 @@ var _ = Describe("Agent test", func() {
|
|||||||
Expect(agent.Character.MusicTaste).ToNot(BeEmpty())
|
Expect(agent.Character.MusicTaste).ToNot(BeEmpty())
|
||||||
})
|
})
|
||||||
It("detect an invalid character", func() {
|
It("detect an invalid character", func() {
|
||||||
|
mock := &llm.MockClient{
|
||||||
|
CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||||
|
return openai.ChatCompletionResponse{}, fmt.Errorf("invalid character")
|
||||||
|
},
|
||||||
|
}
|
||||||
var err error
|
var err error
|
||||||
agent, err = New(WithRandomIdentity())
|
agent, err = New(
|
||||||
|
WithLLMClient(mock),
|
||||||
|
WithRandomIdentity(),
|
||||||
|
)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
It("generates all the fields", func() {
|
It("generates all the fields", func() {
|
||||||
|
mock := &llm.MockClient{
|
||||||
|
CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||||
|
return openai.ChatCompletionResponse{
|
||||||
|
Choices: []openai.ChatCompletionChoice{{
|
||||||
|
Message: openai.ChatCompletionMessage{
|
||||||
|
ToolCalls: []openai.ToolCall{{
|
||||||
|
ID: "tool_call_id_2",
|
||||||
|
Type: "function",
|
||||||
|
Function: openai.FunctionCall{
|
||||||
|
Name: "generate_identity",
|
||||||
|
Arguments: `{"name":"Gandalf","age":"90","job_occupation":"Wizard","hobbies":["magic","reading"],"favorites_music_genres":["Classical"]}`,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
agent, err := New(
|
agent, err := New(
|
||||||
|
WithLLMClient(mock),
|
||||||
WithLLMAPIURL(apiURL),
|
WithLLMAPIURL(apiURL),
|
||||||
WithModel(testModel),
|
WithModel(testModel),
|
||||||
WithRandomIdentity("An 90-year old man with a long beard, a wizard, who lives in a tower."),
|
WithRandomIdentity("An 90-year old man with a long beard, a wizard, who lives in a tower."),
|
||||||
|
|||||||
@@ -1,13 +1,33 @@
|
|||||||
package llm
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAGI/pkg/xlog"
|
||||||
"github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewClient(APIKey, URL, timeout string) *openai.Client {
|
type LLMClient interface {
|
||||||
|
CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)
|
||||||
|
CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type realClient struct {
|
||||||
|
*openai.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *realClient) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||||
|
return r.Client.CreateChatCompletion(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *realClient) CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) {
|
||||||
|
return r.Client.CreateImage(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient returns a real OpenAI client as LLMClient
|
||||||
|
func NewClient(APIKey, URL, timeout string) LLMClient {
|
||||||
// Set up OpenAI client
|
// Set up OpenAI client
|
||||||
if APIKey == "" {
|
if APIKey == "" {
|
||||||
//log.Fatal("OPENAI_API_KEY environment variable not set")
|
//log.Fatal("OPENAI_API_KEY environment variable not set")
|
||||||
@@ -18,11 +38,12 @@ func NewClient(APIKey, URL, timeout string) *openai.Client {
|
|||||||
|
|
||||||
dur, err := time.ParseDuration(timeout)
|
dur, err := time.ParseDuration(timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
xlog.Error("Failed to parse timeout", "error", err)
|
||||||
dur = 150 * time.Second
|
dur = 150 * time.Second
|
||||||
}
|
}
|
||||||
|
|
||||||
config.HTTPClient = &http.Client{
|
config.HTTPClient = &http.Client{
|
||||||
Timeout: dur,
|
Timeout: dur,
|
||||||
}
|
}
|
||||||
return openai.NewClientWithConfig(config)
|
return &realClient{openai.NewClientWithConfig(config)}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GenerateTypedJSONWithGuidance(ctx context.Context, client *openai.Client, guidance, model string, i jsonschema.Definition, dst any) error {
|
func GenerateTypedJSONWithGuidance(ctx context.Context, client LLMClient, guidance, model string, i jsonschema.Definition, dst any) error {
|
||||||
return GenerateTypedJSONWithConversation(ctx, client, []openai.ChatCompletionMessage{
|
return GenerateTypedJSONWithConversation(ctx, client, []openai.ChatCompletionMessage{
|
||||||
{
|
{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
@@ -19,7 +19,7 @@ func GenerateTypedJSONWithGuidance(ctx context.Context, client *openai.Client, g
|
|||||||
}, model, i, dst)
|
}, model, i, dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateTypedJSONWithConversation(ctx context.Context, client *openai.Client, conv []openai.ChatCompletionMessage, model string, i jsonschema.Definition, dst any) error {
|
func GenerateTypedJSONWithConversation(ctx context.Context, client LLMClient, conv []openai.ChatCompletionMessage, model string, i jsonschema.Definition, dst any) error {
|
||||||
toolName := "json"
|
toolName := "json"
|
||||||
decision := openai.ChatCompletionRequest{
|
decision := openai.ChatCompletionRequest{
|
||||||
Model: model,
|
Model: model,
|
||||||
|
|||||||
25
pkg/llm/mock_client.go
Normal file
25
pkg/llm/mock_client.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockClient struct {
|
||||||
|
CreateChatCompletionFunc func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)
|
||||||
|
CreateImageFunc func(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClient) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
|
||||||
|
if m.CreateChatCompletionFunc != nil {
|
||||||
|
return m.CreateChatCompletionFunc(ctx, req)
|
||||||
|
}
|
||||||
|
return openai.ChatCompletionResponse{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClient) CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) {
|
||||||
|
if m.CreateImageFunc != nil {
|
||||||
|
return m.CreateImageFunc(ctx, req)
|
||||||
|
}
|
||||||
|
return openai.ImageResponse{}, nil
|
||||||
|
}
|
||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"github.com/mudler/LocalAGI/core/types"
|
"github.com/mudler/LocalAGI/core/types"
|
||||||
"github.com/mudler/LocalAGI/pkg/config"
|
"github.com/mudler/LocalAGI/pkg/config"
|
||||||
"github.com/mudler/LocalAGI/pkg/llm"
|
"github.com/mudler/LocalAGI/pkg/llm"
|
||||||
"github.com/sashabaranov/go-openai"
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,7 +15,7 @@ const FilterClassifier = "classifier"
|
|||||||
|
|
||||||
type ClassifierFilter struct {
|
type ClassifierFilter struct {
|
||||||
name string
|
name string
|
||||||
client *openai.Client
|
client llm.LLMClient
|
||||||
model string
|
model string
|
||||||
description string
|
description string
|
||||||
allowOnMatch bool
|
allowOnMatch bool
|
||||||
|
|||||||
Reference in New Issue
Block a user