feat(agent): shared state, allow to track conversations globally (#148)
* feat(agent): shared state, allow to track conversations globally Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Cleanup Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * track conversations initiated by the bot Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
2b07dd79ec
commit
c23e655f44
@@ -81,7 +81,7 @@ func (a *CustomAction) Plannable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (a *CustomAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
||||
func (a *CustomAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||
v, err := a.i.Eval(fmt.Sprintf("%s.Run", a.config["name"]))
|
||||
if err != nil {
|
||||
return types.ActionResult{}, err
|
||||
|
||||
@@ -76,7 +76,7 @@ return []string{"foo"}
|
||||
Description: "A test action",
|
||||
}))
|
||||
|
||||
runResult, err := customAction.Run(context.Background(), types.ActionParams{
|
||||
runResult, err := customAction.Run(context.Background(), nil, types.ActionParams{
|
||||
"Foo": "bar",
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -21,7 +21,7 @@ type GoalResponse struct {
|
||||
Achieved bool `json:"achieved"`
|
||||
}
|
||||
|
||||
func (a *GoalAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
||||
func (a *GoalAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||
return types.ActionResult{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ type IntentResponse struct {
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
|
||||
func (a *IntentAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
||||
func (a *IntentAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||
return types.ActionResult{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ type ConversationActionResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (a *ConversationAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
||||
func (a *ConversationAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||
return types.ActionResult{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ func NewStop() *StopAction {
|
||||
|
||||
type StopAction struct{}
|
||||
|
||||
func (a *StopAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
||||
func (a *StopAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||
return types.ActionResult{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ type PlanSubtask struct {
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
|
||||
func (a *PlanAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
||||
func (a *PlanAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||
return types.ActionResult{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ type ReasoningResponse struct {
|
||||
Reasoning string `json:"reasoning"`
|
||||
}
|
||||
|
||||
func (a *ReasoningAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
||||
func (a *ReasoningAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||
return types.ActionResult{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ type ReplyResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (a *ReplyAction) Run(context.Context, types.ActionParams) (string, error) {
|
||||
func (a *ReplyAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (string, error) {
|
||||
return "no-op", nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ func NewState() *StateAction {
|
||||
|
||||
type StateAction struct{}
|
||||
|
||||
func (a *StateAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
||||
func (a *StateAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||
return types.ActionResult{Result: "internal state has been updated"}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -46,6 +46,8 @@ type Agent struct {
|
||||
newMessagesSubscribers []func(openai.ChatCompletionMessage)
|
||||
|
||||
observer Observer
|
||||
|
||||
sharedState *types.AgentSharedState
|
||||
}
|
||||
|
||||
type RAGDB interface {
|
||||
@@ -78,6 +80,7 @@ func New(opts ...Option) (*Agent, error) {
|
||||
context: types.NewActionContext(ctx, cancel),
|
||||
newConversations: make(chan openai.ChatCompletionMessage),
|
||||
newMessagesSubscribers: options.newConversationsSubscribers,
|
||||
sharedState: types.NewAgentSharedState(options.lastMessageDuration),
|
||||
}
|
||||
|
||||
// Initialize observer if provided
|
||||
@@ -118,6 +121,10 @@ func New(opts ...Option) (*Agent, error) {
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *Agent) SharedState() *types.AgentSharedState {
|
||||
return a.sharedState
|
||||
}
|
||||
|
||||
func (a *Agent) startNewConversationsConsumer() {
|
||||
go func() {
|
||||
for {
|
||||
@@ -294,7 +301,7 @@ func (a *Agent) runAction(job *types.Job, chosenAction types.Action, params type
|
||||
|
||||
for _, act := range a.availableActions() {
|
||||
if act.Definition().Name == chosenAction.Definition().Name {
|
||||
res, err := act.Run(job.GetContext(), params)
|
||||
res, err := act.Run(job.GetContext(), a.sharedState, params)
|
||||
if err != nil {
|
||||
if obs != nil {
|
||||
obs.Completion = &types.Completion{
|
||||
|
||||
@@ -44,7 +44,7 @@ func (a *TestAction) Plannable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (a *TestAction) Run(c context.Context, p types.ActionParams) (types.ActionResult, error) {
|
||||
func (a *TestAction) Run(c context.Context, sharedState *types.AgentSharedState, p types.ActionParams) (types.ActionResult, error) {
|
||||
for k, r := range a.response {
|
||||
if strings.Contains(strings.ToLower(p.String()), strings.ToLower(k)) {
|
||||
return types.ActionResult{Result: r}, nil
|
||||
|
||||
@@ -38,7 +38,7 @@ func (a *mcpAction) Plannable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mcpAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
||||
func (m *mcpAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||
resp, err := m.mcpClient.CallTool(ctx, m.toolName, params)
|
||||
if err != nil {
|
||||
xlog.Error("Failed to call tool", "error", err.Error())
|
||||
|
||||
@@ -64,6 +64,8 @@ type options struct {
|
||||
|
||||
observer Observer
|
||||
parallelJobs int
|
||||
|
||||
lastMessageDuration time.Duration
|
||||
}
|
||||
|
||||
func (o *options) SeparatedMultimodalModel() bool {
|
||||
@@ -151,6 +153,17 @@ func EnableKnowledgeBaseWithResults(results int) Option {
|
||||
}
|
||||
}
|
||||
|
||||
func WithLastMessageDuration(duration string) Option {
|
||||
return func(o *options) error {
|
||||
d, err := time.ParseDuration(duration)
|
||||
if err != nil {
|
||||
d = types.DefaultLastMessageDuration
|
||||
}
|
||||
o.lastMessageDuration = d
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithParallelJobs(jobs int) Option {
|
||||
return func(o *options) error {
|
||||
o.parallelJobs = jobs
|
||||
|
||||
@@ -14,10 +14,10 @@ import (
|
||||
// all information that should be displayed to the LLM
|
||||
// in the prompts
|
||||
type PromptHUD struct {
|
||||
Character Character `json:"character"`
|
||||
Character Character `json:"character"`
|
||||
CurrentState types.AgentInternalState `json:"current_state"`
|
||||
PermanentGoal string `json:"permanent_goal"`
|
||||
ShowCharacter bool `json:"show_character"`
|
||||
PermanentGoal string `json:"permanent_goal"`
|
||||
ShowCharacter bool `json:"show_character"`
|
||||
}
|
||||
|
||||
type Character struct {
|
||||
|
||||
13
core/conversations/conversations_suite_test.go
Normal file
13
core/conversations/conversations_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package conversations_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestConversations(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Conversations test suite")
|
||||
}
|
||||
84
core/conversations/conversationstracker.go
Normal file
84
core/conversations/conversationstracker.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package conversations
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAGI/pkg/xlog"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type TrackerKey interface{ ~int | ~int64 | ~string }
|
||||
|
||||
type ConversationTracker[K TrackerKey] struct {
|
||||
convMutex sync.Mutex
|
||||
currentconversation map[K][]openai.ChatCompletionMessage
|
||||
lastMessageTime map[K]time.Time
|
||||
lastMessageDuration time.Duration
|
||||
}
|
||||
|
||||
func NewConversationTracker[K TrackerKey](lastMessageDuration time.Duration) *ConversationTracker[K] {
|
||||
return &ConversationTracker[K]{
|
||||
lastMessageDuration: lastMessageDuration,
|
||||
currentconversation: map[K][]openai.ChatCompletionMessage{},
|
||||
lastMessageTime: map[K]time.Time{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConversationTracker[K]) GetConversation(key K) []openai.ChatCompletionMessage {
|
||||
// Lock the conversation mutex to update the conversation history
|
||||
c.convMutex.Lock()
|
||||
defer c.convMutex.Unlock()
|
||||
|
||||
// Clear up the conversation if the last message was sent more than lastMessageDuration ago
|
||||
currentConv := []openai.ChatCompletionMessage{}
|
||||
lastMessageTime := c.lastMessageTime[key]
|
||||
if lastMessageTime.IsZero() {
|
||||
lastMessageTime = time.Now()
|
||||
}
|
||||
if lastMessageTime.Add(c.lastMessageDuration).Before(time.Now()) {
|
||||
currentConv = []openai.ChatCompletionMessage{}
|
||||
c.lastMessageTime[key] = time.Now()
|
||||
xlog.Debug("Conversation history does not exist for", "key", fmt.Sprintf("%v", key))
|
||||
} else {
|
||||
xlog.Debug("Conversation history exists for", "key", fmt.Sprintf("%v", key))
|
||||
currentConv = append(currentConv, c.currentconversation[key]...)
|
||||
}
|
||||
|
||||
// cleanup other conversations if older
|
||||
for k := range c.currentconversation {
|
||||
lastMessage, exists := c.lastMessageTime[k]
|
||||
if !exists {
|
||||
delete(c.currentconversation, k)
|
||||
delete(c.lastMessageTime, k)
|
||||
continue
|
||||
}
|
||||
if lastMessage.Add(c.lastMessageDuration).Before(time.Now()) {
|
||||
xlog.Debug("Cleaning up conversation for", k)
|
||||
delete(c.currentconversation, k)
|
||||
delete(c.lastMessageTime, k)
|
||||
}
|
||||
}
|
||||
|
||||
return currentConv
|
||||
|
||||
}
|
||||
|
||||
func (c *ConversationTracker[K]) AddMessage(key K, message openai.ChatCompletionMessage) {
|
||||
// Lock the conversation mutex to update the conversation history
|
||||
c.convMutex.Lock()
|
||||
defer c.convMutex.Unlock()
|
||||
|
||||
c.currentconversation[key] = append(c.currentconversation[key], message)
|
||||
c.lastMessageTime[key] = time.Now()
|
||||
}
|
||||
|
||||
func (c *ConversationTracker[K]) SetConversation(key K, messages []openai.ChatCompletionMessage) {
|
||||
// Lock the conversation mutex to update the conversation history
|
||||
c.convMutex.Lock()
|
||||
defer c.convMutex.Unlock()
|
||||
|
||||
c.currentconversation[key] = messages
|
||||
c.lastMessageTime[key] = time.Now()
|
||||
}
|
||||
111
core/conversations/conversationstracker_test.go
Normal file
111
core/conversations/conversationstracker_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package conversations_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAGI/core/conversations"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
var _ = Describe("ConversationTracker", func() {
|
||||
var (
|
||||
tracker *conversations.ConversationTracker[string]
|
||||
duration time.Duration
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
duration = 1 * time.Second
|
||||
tracker = conversations.NewConversationTracker[string](duration)
|
||||
})
|
||||
|
||||
It("should initialize with empty conversations", func() {
|
||||
Expect(tracker.GetConversation("test")).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should add a message and retrieve it", func() {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello",
|
||||
}
|
||||
tracker.AddMessage("test", message)
|
||||
conv := tracker.GetConversation("test")
|
||||
Expect(conv).To(HaveLen(1))
|
||||
Expect(conv[0]).To(Equal(message))
|
||||
})
|
||||
|
||||
It("should clear the conversation after the duration", func() {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello",
|
||||
}
|
||||
tracker.AddMessage("test", message)
|
||||
time.Sleep(2 * time.Second)
|
||||
conv := tracker.GetConversation("test")
|
||||
Expect(conv).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should keep the conversation within the duration", func() {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello",
|
||||
}
|
||||
tracker.AddMessage("test", message)
|
||||
time.Sleep(500 * time.Millisecond) // Half the duration
|
||||
conv := tracker.GetConversation("test")
|
||||
Expect(conv).To(HaveLen(1))
|
||||
Expect(conv[0]).To(Equal(message))
|
||||
})
|
||||
|
||||
It("should handle multiple keys and clear old conversations", func() {
|
||||
message1 := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello 1",
|
||||
}
|
||||
message2 := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello 2",
|
||||
}
|
||||
|
||||
tracker.AddMessage("key1", message1)
|
||||
tracker.AddMessage("key2", message2)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
conv1 := tracker.GetConversation("key1")
|
||||
conv2 := tracker.GetConversation("key2")
|
||||
|
||||
Expect(conv1).To(BeEmpty())
|
||||
Expect(conv2).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should handle different key types", func() {
|
||||
trackerInt := conversations.NewConversationTracker[int](duration)
|
||||
trackerInt64 := conversations.NewConversationTracker[int64](duration)
|
||||
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello",
|
||||
}
|
||||
|
||||
trackerInt.AddMessage(1, message)
|
||||
trackerInt64.AddMessage(int64(1), message)
|
||||
|
||||
Expect(trackerInt.GetConversation(1)).To(HaveLen(1))
|
||||
Expect(trackerInt64.GetConversation(int64(1))).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("should cleanup other conversations if older", func() {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello",
|
||||
}
|
||||
tracker.AddMessage("key1", message)
|
||||
tracker.AddMessage("key2", message)
|
||||
time.Sleep(2 * time.Second)
|
||||
tracker.GetConversation("key3")
|
||||
Expect(tracker.GetConversation("key1")).To(BeEmpty())
|
||||
Expect(tracker.GetConversation("key2")).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
@@ -48,12 +48,13 @@ type AgentConfig struct {
|
||||
|
||||
Description string `json:"description" form:"description"`
|
||||
|
||||
Model string `json:"model" form:"model"`
|
||||
MultimodalModel string `json:"multimodal_model" form:"multimodal_model"`
|
||||
APIURL string `json:"api_url" form:"api_url"`
|
||||
APIKey string `json:"api_key" form:"api_key"`
|
||||
LocalRAGURL string `json:"local_rag_url" form:"local_rag_url"`
|
||||
LocalRAGAPIKey string `json:"local_rag_api_key" form:"local_rag_api_key"`
|
||||
Model string `json:"model" form:"model"`
|
||||
MultimodalModel string `json:"multimodal_model" form:"multimodal_model"`
|
||||
APIURL string `json:"api_url" form:"api_url"`
|
||||
APIKey string `json:"api_key" form:"api_key"`
|
||||
LocalRAGURL string `json:"local_rag_url" form:"local_rag_url"`
|
||||
LocalRAGAPIKey string `json:"local_rag_api_key" form:"local_rag_api_key"`
|
||||
LastMessageDuration string `json:"last_message_duration" form:"last_message_duration"`
|
||||
|
||||
Name string `json:"name" form:"name"`
|
||||
HUD bool `json:"hud" form:"hud"`
|
||||
@@ -329,6 +330,14 @@ func NewAgentConfigMeta(
|
||||
HelpText: "Maximum number of evaluation loops to perform when addressing gaps in responses",
|
||||
Tags: config.Tags{Section: "AdvancedSettings"},
|
||||
},
|
||||
{
|
||||
Name: "last_message_duration",
|
||||
Label: "Last Message Duration",
|
||||
Type: "text",
|
||||
DefaultValue: "5m",
|
||||
HelpText: "Duration for the last message to be considered in the conversation",
|
||||
Tags: config.Tags{Section: "AdvancedSettings"},
|
||||
},
|
||||
},
|
||||
MCPServers: []config.Field{
|
||||
{
|
||||
|
||||
@@ -462,6 +462,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O
|
||||
}),
|
||||
WithSystemPrompt(config.SystemPrompt),
|
||||
WithMultimodalModel(multimodalModel),
|
||||
WithLastMessageDuration(config.LastMessageDuration),
|
||||
WithAgentResultCallback(func(state types.ActionState) {
|
||||
a.Lock()
|
||||
if _, ok := a.agentStatus[name]; !ok {
|
||||
|
||||
@@ -88,7 +88,7 @@ func (a ActionDefinition) ToFunctionDefinition() *openai.FunctionDefinition {
|
||||
|
||||
// Actions is something the agent can do
|
||||
type Action interface {
|
||||
Run(ctx context.Context, action ActionParams) (ActionResult, error)
|
||||
Run(ctx context.Context, sharedState *AgentSharedState, action ActionParams) (ActionResult, error)
|
||||
Definition() ActionDefinition
|
||||
Plannable() bool
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package types
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAGI/core/conversations"
|
||||
)
|
||||
|
||||
// State is the structure
|
||||
// that is used to keep track of the current state
|
||||
@@ -20,6 +25,23 @@ type AgentInternalState struct {
|
||||
Goal string `json:"goal"`
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultLastMessageDuration = 5 * time.Minute
|
||||
)
|
||||
|
||||
type AgentSharedState struct {
|
||||
ConversationTracker *conversations.ConversationTracker[string] `json:"conversation_tracker"`
|
||||
}
|
||||
|
||||
func NewAgentSharedState(lastMessageDuration time.Duration) *AgentSharedState {
|
||||
if lastMessageDuration == 0 {
|
||||
lastMessageDuration = DefaultLastMessageDuration
|
||||
}
|
||||
return &AgentSharedState{
|
||||
ConversationTracker: conversations.NewConversationTracker[string](lastMessageDuration),
|
||||
}
|
||||
}
|
||||
|
||||
const fmtT = `=====================
|
||||
NowDoing: %s
|
||||
DoingNext: %s
|
||||
|
||||
Reference in New Issue
Block a user