From c23e655f44ba649aac803ff840e46c613cb0d31c Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 11 May 2025 22:23:01 +0200 Subject: [PATCH] feat(agent): shared state, allow to track conversations globally (#148) * feat(agent): shared state, allow to track conversations globally Signed-off-by: Ettore Di Giacinto * Cleanup Signed-off-by: Ettore Di Giacinto * track conversations initiated by the bot Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- core/action/custom.go | 2 +- core/action/custom_test.go | 2 +- core/action/goal.go | 2 +- core/action/intention.go | 2 +- core/action/newconversation.go | 2 +- core/action/noreply.go | 2 +- core/action/plan.go | 2 +- core/action/reasoning.go | 2 +- core/action/reply.go | 2 +- core/action/state.go | 2 +- core/agent/agent.go | 9 +- core/agent/agent_test.go | 2 +- core/agent/mcp.go | 2 +- core/agent/options.go | 13 +++ core/agent/state.go | 6 +- .../conversations/conversations_suite_test.go | 13 +++ .../conversations}/conversationstracker.go | 2 +- .../conversationstracker_test.go | 12 +-- core/state/config.go | 21 +++-- core/state/pool.go | 1 + core/types/actions.go | 2 +- core/types/state.go | 24 ++++- services/actions/browse.go | 2 +- services/actions/browseragentrunner.go | 2 +- services/actions/callagents.go | 2 +- services/actions/counter.go | 2 +- services/actions/deepresearchrunner.go | 2 +- services/actions/genimage.go | 2 +- services/actions/genimage_test.go | 4 +- services/actions/githubissuecloser.go | 2 +- services/actions/githubissuecomment.go | 2 +- services/actions/githubissueedit.go | 2 +- services/actions/githubissuelabeler.go | 2 +- services/actions/githubissueopener.go | 2 +- services/actions/githubissuereader.go | 2 +- services/actions/githubissuesearch.go | 2 +- services/actions/githubprcommenter.go | 94 +------------------ services/actions/githubprcreator.go | 2 +- services/actions/githubprcreator_test.go | 4 +- services/actions/githubprreader.go | 2 +- services/actions/githubprreviewer.go | 2 +- services/actions/githubprreviewer_test.go | 6 +- .../githubrepositorycreateupdatecontent.go | 2 +- .../actions/githubrepositorygetallcontent.go | 2 +- .../githubrepositorygetallcontent_test.go | 4 +- .../actions/githubrepositorygetcontent.go | 2 +- services/actions/githubrepositorylistfiles.go | 2 +- services/actions/githubrepositoryreadme.go | 2 +- .../actions/githubrepositorysearchfiles.go | 2 +- services/actions/scrape.go | 2 +- services/actions/search.go | 2 +- services/actions/sendmail.go | 2 +- services/actions/sendtelegrammessage.go | 57 ++++++++--- services/actions/shell.go | 2 +- services/actions/twitter_post.go | 2 +- services/actions/wikipedia.go | 2 +- services/connectors/discord.go | 23 ++--- services/connectors/irc.go | 42 +++------ services/connectors/matrix.go | 37 +++----- services/connectors/slack.go | 56 ++++------- services/connectors/telegram.go | 72 ++++++++------ webui/app.go | 18 ++-- webui/routes.go | 6 +- 63 files changed, 290 insertions(+), 316 deletions(-) create mode 100644 core/conversations/conversations_suite_test.go rename {services/connectors => core/conversations}/conversationstracker.go (99%) rename {services/connectors => core/conversations}/conversationstracker_test.go (89%) diff --git a/core/action/custom.go b/core/action/custom.go index af5d590..96e7957 100644 --- a/core/action/custom.go +++ b/core/action/custom.go @@ -81,7 +81,7 @@ func (a *CustomAction) Plannable() bool { return true } -func (a *CustomAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (a *CustomAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { v, err := a.i.Eval(fmt.Sprintf("%s.Run", a.config["name"])) if err != nil { return types.ActionResult{}, err diff --git a/core/action/custom_test.go b/core/action/custom_test.go index f0171db..9d885b5 100644 --- a/core/action/custom_test.go +++ b/core/action/custom_test.go @@ -76,7 +76,7 @@ return []string{"foo"} Description: "A test action", })) - runResult, err := customAction.Run(context.Background(), types.ActionParams{ + runResult, err := customAction.Run(context.Background(), nil, types.ActionParams{ "Foo": "bar", }) Expect(err).ToNot(HaveOccurred()) diff --git a/core/action/goal.go b/core/action/goal.go index 7c0851d..e746201 100644 --- a/core/action/goal.go +++ b/core/action/goal.go @@ -21,7 +21,7 @@ type GoalResponse struct { Achieved bool `json:"achieved"` } -func (a *GoalAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { +func (a *GoalAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { return types.ActionResult{}, nil } diff --git a/core/action/intention.go b/core/action/intention.go index 8316c9f..229fd07 100644 --- a/core/action/intention.go +++ b/core/action/intention.go @@ -22,7 +22,7 @@ type IntentResponse struct { Reasoning string `json:"reasoning"` } -func (a *IntentAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { +func (a *IntentAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { return types.ActionResult{}, nil } diff --git a/core/action/newconversation.go b/core/action/newconversation.go index aa1a859..5fa9dc8 100644 --- a/core/action/newconversation.go +++ b/core/action/newconversation.go @@ -19,7 +19,7 @@ type ConversationActionResponse struct { Message string `json:"message"` } -func (a *ConversationAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { +func (a *ConversationAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { return types.ActionResult{}, nil } diff --git a/core/action/noreply.go b/core/action/noreply.go index c2ed874..f0e77a3 100644 --- a/core/action/noreply.go +++ b/core/action/noreply.go @@ -16,7 +16,7 @@ func NewStop() *StopAction { type StopAction struct{} -func (a *StopAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { +func (a *StopAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { return types.ActionResult{}, nil } diff --git a/core/action/plan.go b/core/action/plan.go index 6f8d5b3..5685cac 100644 --- a/core/action/plan.go +++ b/core/action/plan.go @@ -30,7 +30,7 @@ type PlanSubtask struct { Reasoning string `json:"reasoning"` } -func (a *PlanAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { +func (a *PlanAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { return types.ActionResult{}, nil } diff --git a/core/action/reasoning.go b/core/action/reasoning.go index ee077ed..7e94fd9 100644 --- a/core/action/reasoning.go +++ b/core/action/reasoning.go @@ -20,7 +20,7 @@ type ReasoningResponse struct { Reasoning string `json:"reasoning"` } -func (a *ReasoningAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { +func (a *ReasoningAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { return types.ActionResult{}, nil } diff --git a/core/action/reply.go b/core/action/reply.go index 2d99c68..2fe00a2 100644 --- a/core/action/reply.go +++ b/core/action/reply.go @@ -22,7 +22,7 @@ type ReplyResponse struct { Message string `json:"message"` } -func (a *ReplyAction) Run(context.Context, types.ActionParams) (string, error) { +func (a *ReplyAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (string, error) { return "no-op", nil } diff --git a/core/action/state.go b/core/action/state.go index d0ac586..ed75a48 100644 --- a/core/action/state.go +++ b/core/action/state.go @@ -15,7 +15,7 @@ func NewState() *StateAction { type StateAction struct{} -func (a *StateAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { +func (a *StateAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { return types.ActionResult{Result: "internal state has been updated"}, nil } diff --git a/core/agent/agent.go b/core/agent/agent.go index 016d9bf..a614bb1 100644 --- a/core/agent/agent.go +++ b/core/agent/agent.go @@ -46,6 +46,8 @@ type Agent struct { newMessagesSubscribers []func(openai.ChatCompletionMessage) observer Observer + + sharedState *types.AgentSharedState } type RAGDB interface { @@ -78,6 +80,7 @@ func New(opts ...Option) (*Agent, error) { context: types.NewActionContext(ctx, cancel), newConversations: make(chan openai.ChatCompletionMessage), newMessagesSubscribers: options.newConversationsSubscribers, + sharedState: types.NewAgentSharedState(options.lastMessageDuration), } // Initialize observer if provided @@ -118,6 +121,10 @@ func New(opts ...Option) (*Agent, error) { return a, nil } +func (a *Agent) SharedState() *types.AgentSharedState { + return a.sharedState +} + func (a *Agent) startNewConversationsConsumer() { go func() { for { @@ -294,7 +301,7 @@ func (a *Agent) runAction(job *types.Job, chosenAction types.Action, params type for _, act := range a.availableActions() { if act.Definition().Name == chosenAction.Definition().Name { - res, err := act.Run(job.GetContext(), params) + res, err := act.Run(job.GetContext(), a.sharedState, params) if err != nil { if obs != nil { obs.Completion = &types.Completion{ diff --git a/core/agent/agent_test.go b/core/agent/agent_test.go index 0b80945..cdc2e8b 100644 --- a/core/agent/agent_test.go +++ b/core/agent/agent_test.go @@ -44,7 +44,7 @@ func (a *TestAction) Plannable() bool { return true } -func (a *TestAction) Run(c context.Context, p types.ActionParams) (types.ActionResult, error) { +func (a *TestAction) Run(c context.Context, sharedState *types.AgentSharedState, p types.ActionParams) (types.ActionResult, error) { for k, r := range a.response { if strings.Contains(strings.ToLower(p.String()), strings.ToLower(k)) { return types.ActionResult{Result: r}, nil diff --git a/core/agent/mcp.go b/core/agent/mcp.go index 0d3987b..bf05124 100644 --- a/core/agent/mcp.go +++ b/core/agent/mcp.go @@ -38,7 +38,7 @@ func (a *mcpAction) Plannable() bool { return true } -func (m *mcpAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (m *mcpAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { resp, err := m.mcpClient.CallTool(ctx, m.toolName, params) if err != nil { xlog.Error("Failed to call tool", "error", err.Error()) diff --git a/core/agent/options.go b/core/agent/options.go index c7f4514..e1c0e97 100644 --- a/core/agent/options.go +++ b/core/agent/options.go @@ -64,6 +64,8 @@ type options struct { observer Observer parallelJobs int + + lastMessageDuration time.Duration } func (o *options) SeparatedMultimodalModel() bool { @@ -151,6 +153,17 @@ func EnableKnowledgeBaseWithResults(results int) Option { } } +func WithLastMessageDuration(duration string) Option { + return func(o *options) error { + d, err := time.ParseDuration(duration) + if err != nil { + d = types.DefaultLastMessageDuration + } + o.lastMessageDuration = d + return nil + } +} + func WithParallelJobs(jobs int) Option { return func(o *options) error { o.parallelJobs = jobs diff --git a/core/agent/state.go b/core/agent/state.go index 2c35306..2f1cf41 100644 --- a/core/agent/state.go +++ b/core/agent/state.go @@ -14,10 +14,10 @@ import ( // all information that should be displayed to the LLM // in the prompts type PromptHUD struct { - Character Character `json:"character"` + Character Character `json:"character"` CurrentState types.AgentInternalState `json:"current_state"` - PermanentGoal string `json:"permanent_goal"` - ShowCharacter bool `json:"show_character"` + PermanentGoal string `json:"permanent_goal"` + ShowCharacter bool `json:"show_character"` } type Character struct { diff --git a/core/conversations/conversations_suite_test.go b/core/conversations/conversations_suite_test.go new file mode 100644 index 0000000..8fcfed4 --- /dev/null +++ b/core/conversations/conversations_suite_test.go @@ -0,0 +1,13 @@ +package conversations_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestConversations(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Conversations test suite") +} diff --git a/services/connectors/conversationstracker.go b/core/conversations/conversationstracker.go similarity index 99% rename from services/connectors/conversationstracker.go rename to core/conversations/conversationstracker.go index 5e70e5a..c431e88 100644 --- a/services/connectors/conversationstracker.go +++ b/core/conversations/conversationstracker.go @@ -1,4 +1,4 @@ -package connectors +package conversations import ( "fmt" diff --git a/services/connectors/conversationstracker_test.go b/core/conversations/conversationstracker_test.go similarity index 89% rename from services/connectors/conversationstracker_test.go rename to core/conversations/conversationstracker_test.go index 091e3b5..2fd2250 100644 --- a/services/connectors/conversationstracker_test.go +++ b/core/conversations/conversationstracker_test.go @@ -1,9 +1,9 @@ -package connectors_test +package conversations_test import ( "time" - "github.com/mudler/LocalAGI/services/connectors" + "github.com/mudler/LocalAGI/core/conversations" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/sashabaranov/go-openai" @@ -11,13 +11,13 @@ import ( var _ = Describe("ConversationTracker", func() { var ( - tracker *connectors.ConversationTracker[string] + tracker *conversations.ConversationTracker[string] duration time.Duration ) BeforeEach(func() { duration = 1 * time.Second - tracker = connectors.NewConversationTracker[string](duration) + tracker = conversations.NewConversationTracker[string](duration) }) It("should initialize with empty conversations", func() { @@ -81,8 +81,8 @@ var _ = Describe("ConversationTracker", func() { }) It("should handle different key types", func() { - trackerInt := connectors.NewConversationTracker[int](duration) - trackerInt64 := connectors.NewConversationTracker[int64](duration) + trackerInt := conversations.NewConversationTracker[int](duration) + trackerInt64 := conversations.NewConversationTracker[int64](duration) message := openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, diff --git a/core/state/config.go b/core/state/config.go index 0898975..4743f43 100644 --- a/core/state/config.go +++ b/core/state/config.go @@ -48,12 +48,13 @@ type AgentConfig struct { Description string `json:"description" form:"description"` - Model string `json:"model" form:"model"` - MultimodalModel string `json:"multimodal_model" form:"multimodal_model"` - APIURL string `json:"api_url" form:"api_url"` - APIKey string `json:"api_key" form:"api_key"` - LocalRAGURL string `json:"local_rag_url" form:"local_rag_url"` - LocalRAGAPIKey string `json:"local_rag_api_key" form:"local_rag_api_key"` + Model string `json:"model" form:"model"` + MultimodalModel string `json:"multimodal_model" form:"multimodal_model"` + APIURL string `json:"api_url" form:"api_url"` + APIKey string `json:"api_key" form:"api_key"` + LocalRAGURL string `json:"local_rag_url" form:"local_rag_url"` + LocalRAGAPIKey string `json:"local_rag_api_key" form:"local_rag_api_key"` + LastMessageDuration string `json:"last_message_duration" form:"last_message_duration"` Name string `json:"name" form:"name"` HUD bool `json:"hud" form:"hud"` @@ -329,6 +330,14 @@ func NewAgentConfigMeta( HelpText: "Maximum number of evaluation loops to perform when addressing gaps in responses", Tags: config.Tags{Section: "AdvancedSettings"}, }, + { + Name: "last_message_duration", + Label: "Last Message Duration", + Type: "text", + DefaultValue: "5m", + HelpText: "Duration for the last message to be considered in the conversation", + Tags: config.Tags{Section: "AdvancedSettings"}, + }, }, MCPServers: []config.Field{ { diff --git a/core/state/pool.go b/core/state/pool.go index 1ca8415..4214115 100644 --- a/core/state/pool.go +++ b/core/state/pool.go @@ -462,6 +462,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O }), WithSystemPrompt(config.SystemPrompt), WithMultimodalModel(multimodalModel), + WithLastMessageDuration(config.LastMessageDuration), WithAgentResultCallback(func(state types.ActionState) { a.Lock() if _, ok := a.agentStatus[name]; !ok { diff --git a/core/types/actions.go b/core/types/actions.go index 050c765..cc9ae1d 100644 --- a/core/types/actions.go +++ b/core/types/actions.go @@ -88,7 +88,7 @@ func (a ActionDefinition) ToFunctionDefinition() *openai.FunctionDefinition { // Actions is something the agent can do type Action interface { - Run(ctx context.Context, action ActionParams) (ActionResult, error) + Run(ctx context.Context, sharedState *AgentSharedState, action ActionParams) (ActionResult, error) Definition() ActionDefinition Plannable() bool } diff --git a/core/types/state.go b/core/types/state.go index e3baa5e..b4d7e4d 100644 --- a/core/types/state.go +++ b/core/types/state.go @@ -1,6 +1,11 @@ package types -import "fmt" +import ( + "fmt" + "time" + + "github.com/mudler/LocalAGI/core/conversations" +) // State is the structure // that is used to keep track of the current state @@ -20,6 +25,23 @@ type AgentInternalState struct { Goal string `json:"goal"` } +const ( + DefaultLastMessageDuration = 5 * time.Minute +) + +type AgentSharedState struct { + ConversationTracker *conversations.ConversationTracker[string] `json:"conversation_tracker"` +} + +func NewAgentSharedState(lastMessageDuration time.Duration) *AgentSharedState { + if lastMessageDuration == 0 { + lastMessageDuration = DefaultLastMessageDuration + } + return &AgentSharedState{ + ConversationTracker: conversations.NewConversationTracker[string](lastMessageDuration), + } +} + const fmtT = `===================== NowDoing: %s DoingNext: %s diff --git a/services/actions/browse.go b/services/actions/browse.go index 9c57073..20c1f6c 100644 --- a/services/actions/browse.go +++ b/services/actions/browse.go @@ -18,7 +18,7 @@ func NewBrowse(config map[string]string) *BrowseAction { type BrowseAction struct{} -func (a *BrowseAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (a *BrowseAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { URL string `json:"url"` }{} diff --git a/services/actions/browseragentrunner.go b/services/actions/browseragentrunner.go index f339caf..b70bd62 100644 --- a/services/actions/browseragentrunner.go +++ b/services/actions/browseragentrunner.go @@ -45,7 +45,7 @@ func NewBrowserAgentRunner(config map[string]string, defaultURL string) *Browser } } -func (b *BrowserAgentRunner) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (b *BrowserAgentRunner) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := api.AgentRequest{} err := params.Unmarshal(&result) if err != nil { diff --git a/services/actions/callagents.go b/services/actions/callagents.go index d4b65ac..ce0b57b 100644 --- a/services/actions/callagents.go +++ b/services/actions/callagents.go @@ -52,7 +52,7 @@ type CallAgentAction struct { blacklist []string } -func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (a *CallAgentAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { AgentName string `json:"agent_name"` Message string `json:"message"` diff --git a/services/actions/counter.go b/services/actions/counter.go index e12da96..1c244e3 100644 --- a/services/actions/counter.go +++ b/services/actions/counter.go @@ -24,7 +24,7 @@ func NewCounter(config map[string]string) *CounterAction { } // Run executes the counter action -func (a *CounterAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (a *CounterAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { // Parse parameters request := struct { Name string `json:"name"` diff --git a/services/actions/deepresearchrunner.go b/services/actions/deepresearchrunner.go index 68902d6..407ba81 100644 --- a/services/actions/deepresearchrunner.go +++ b/services/actions/deepresearchrunner.go @@ -45,7 +45,7 @@ func NewDeepResearchRunner(config map[string]string, defaultURL string) *DeepRes } } -func (d *DeepResearchRunner) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (d *DeepResearchRunner) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := api.DeepResearchRequest{} err := params.Unmarshal(&result) if err != nil { diff --git a/services/actions/genimage.go b/services/actions/genimage.go index 61a101d..7c5045e 100644 --- a/services/actions/genimage.go +++ b/services/actions/genimage.go @@ -29,7 +29,7 @@ type GenImageAction struct { imageModel string } -func (a *GenImageAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (a *GenImageAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Prompt string `json:"prompt"` Size string `json:"size"` diff --git a/services/actions/genimage_test.go b/services/actions/genimage_test.go index ddaa999..0902b2d 100644 --- a/services/actions/genimage_test.go +++ b/services/actions/genimage_test.go @@ -42,7 +42,7 @@ var _ = Describe("GenImageAction", func() { "size": "256x256", } - url, err := action.Run(ctx, params) + url, err := action.Run(ctx, nil, params) Expect(err).ToNot(HaveOccurred()) Expect(url).ToNot(BeEmpty()) }) @@ -52,7 +52,7 @@ var _ = Describe("GenImageAction", func() { "size": "256x256", } - _, err := action.Run(ctx, params) + _, err := action.Run(ctx, nil, params) Expect(err).To(HaveOccurred()) }) }) diff --git a/services/actions/githubissuecloser.go b/services/actions/githubissuecloser.go index 6a1ce08..99aa2cd 100644 --- a/services/actions/githubissuecloser.go +++ b/services/actions/githubissuecloser.go @@ -26,7 +26,7 @@ func NewGithubIssueCloser(config map[string]string) *GithubIssuesCloser { } } -func (g *GithubIssuesCloser) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubIssuesCloser) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubissuecomment.go b/services/actions/githubissuecomment.go index 4370b1b..3854b32 100644 --- a/services/actions/githubissuecomment.go +++ b/services/actions/githubissuecomment.go @@ -27,7 +27,7 @@ func NewGithubIssueCommenter(config map[string]string) *GithubIssuesCommenter { } } -func (g *GithubIssuesCommenter) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubIssuesCommenter) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubissueedit.go b/services/actions/githubissueedit.go index 81e2139..0c933eb 100644 --- a/services/actions/githubissueedit.go +++ b/services/actions/githubissueedit.go @@ -27,7 +27,7 @@ func NewGithubIssueEditor(config map[string]string) *GithubIssueEditor { } } -func (g *GithubIssueEditor) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubIssueEditor) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubissuelabeler.go b/services/actions/githubissuelabeler.go index 90faebd..78bd803 100644 --- a/services/actions/githubissuelabeler.go +++ b/services/actions/githubissuelabeler.go @@ -38,7 +38,7 @@ func NewGithubIssueLabeler(config map[string]string) *GithubIssuesLabeler { } } -func (g *GithubIssuesLabeler) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubIssuesLabeler) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubissueopener.go b/services/actions/githubissueopener.go index 5bf5e6f..7d75238 100644 --- a/services/actions/githubissueopener.go +++ b/services/actions/githubissueopener.go @@ -27,7 +27,7 @@ func NewGithubIssueOpener(config map[string]string) *GithubIssuesOpener { } } -func (g *GithubIssuesOpener) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubIssuesOpener) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Title string `json:"title"` Body string `json:"text"` diff --git a/services/actions/githubissuereader.go b/services/actions/githubissuereader.go index 39b2f02..7e23466 100644 --- a/services/actions/githubissuereader.go +++ b/services/actions/githubissuereader.go @@ -27,7 +27,7 @@ func NewGithubIssueReader(config map[string]string) *GithubIssuesReader { } } -func (g *GithubIssuesReader) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubIssuesReader) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubissuesearch.go b/services/actions/githubissuesearch.go index f858ab7..d6b518a 100644 --- a/services/actions/githubissuesearch.go +++ b/services/actions/githubissuesearch.go @@ -28,7 +28,7 @@ func NewGithubIssueSearch(config map[string]string) *GithubIssueSearch { } } -func (g *GithubIssueSearch) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubIssueSearch) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Query string `json:"query"` Repository string `json:"repository"` diff --git a/services/actions/githubprcommenter.go b/services/actions/githubprcommenter.go index 810b6e1..0a2c7fe 100644 --- a/services/actions/githubprcommenter.go +++ b/services/actions/githubprcommenter.go @@ -3,8 +3,6 @@ package actions import ( "context" "fmt" - "regexp" - "strconv" "github.com/google/go-github/v69/github" "github.com/mudler/LocalAGI/core/types" @@ -17,96 +15,6 @@ type GithubPRCommenter struct { client *github.Client } -var ( - patchRegex = regexp.MustCompile(`^@@.*\d [\+\-](\d+),?(\d+)?.+?@@`) -) - -type commitFileInfo struct { - FileName string - hunkInfos []*hunkInfo - sha string -} - -type hunkInfo struct { - hunkStart int - hunkEnd int -} - -func (hi hunkInfo) isLineInHunk(line int) bool { - return line >= hi.hunkStart && line <= hi.hunkEnd -} - -func (cfi *commitFileInfo) getHunkInfo(line int) *hunkInfo { - for _, hunkInfo := range cfi.hunkInfos { - if hunkInfo.isLineInHunk(line) { - return hunkInfo - } - } - return nil -} - -func (cfi *commitFileInfo) isLineInChange(line int) bool { - return cfi.getHunkInfo(line) != nil -} - -func (cfi commitFileInfo) calculatePosition(line int) *int { - hi := cfi.getHunkInfo(line) - if hi == nil { - return nil - } - position := line - hi.hunkStart - return &position -} - -func parseHunkPositions(patch, filename string) ([]*hunkInfo, error) { - hunkInfos := make([]*hunkInfo, 0) - if patch != "" { - groups := patchRegex.FindAllStringSubmatch(patch, -1) - if len(groups) < 1 { - return hunkInfos, fmt.Errorf("the patch details for [%s] could not be resolved", filename) - } - for _, patchGroup := range groups { - endPos := 2 - if len(patchGroup) > 2 && patchGroup[2] == "" { - endPos = 1 - } - - hunkStart, err := strconv.Atoi(patchGroup[1]) - if err != nil { - hunkStart = -1 - } - hunkEnd, err := strconv.Atoi(patchGroup[endPos]) - if err != nil { - hunkEnd = -1 - } - hunkInfos = append(hunkInfos, &hunkInfo{ - hunkStart: hunkStart, - hunkEnd: hunkEnd, - }) - } - } - return hunkInfos, nil -} - -func getCommitInfo(file *github.CommitFile) (*commitFileInfo, error) { - patch := file.GetPatch() - hunkInfos, err := parseHunkPositions(patch, *file.Filename) - if err != nil { - return nil, err - } - - sha := file.GetSHA() - if sha == "" { - return nil, fmt.Errorf("the sha details for [%s] could not be resolved", *file.Filename) - } - - return &commitFileInfo{ - FileName: *file.Filename, - hunkInfos: hunkInfos, - sha: sha, - }, nil -} - func NewGithubPRCommenter(config map[string]string) *GithubPRCommenter { client := github.NewClient(nil).WithAuthToken(config["token"]) @@ -119,7 +27,7 @@ func NewGithubPRCommenter(config map[string]string) *GithubPRCommenter { } } -func (g *GithubPRCommenter) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubPRCommenter) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubprcreator.go b/services/actions/githubprcreator.go index eb326ee..306ad40 100644 --- a/services/actions/githubprcreator.go +++ b/services/actions/githubprcreator.go @@ -148,7 +148,7 @@ func (g *GithubPRCreator) createOrUpdateFile(ctx context.Context, branch string, return nil } -func (g *GithubPRCreator) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubPRCreator) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubprcreator_test.go b/services/actions/githubprcreator_test.go index c96c9f6..ffb97d1 100644 --- a/services/actions/githubprcreator_test.go +++ b/services/actions/githubprcreator_test.go @@ -54,7 +54,7 @@ var _ = Describe("GithubPRCreator", func() { }, } - result, err := action.Run(ctx, params) + result, err := action.Run(ctx, nil, params) Expect(err).NotTo(HaveOccurred()) Expect(result.Result).To(ContainSubstring("pull request #")) }) @@ -65,7 +65,7 @@ var _ = Describe("GithubPRCreator", func() { "body": "This is a test pull request", } - _, err := action.Run(ctx, params) + _, err := action.Run(ctx, nil, params) Expect(err).To(HaveOccurred()) }) }) diff --git a/services/actions/githubprreader.go b/services/actions/githubprreader.go index 767ef47..393d31a 100644 --- a/services/actions/githubprreader.go +++ b/services/actions/githubprreader.go @@ -34,7 +34,7 @@ func NewGithubPRReader(config map[string]string) *GithubPRReader { } } -func (g *GithubPRReader) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubPRReader) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubprreviewer.go b/services/actions/githubprreviewer.go index 8ad2d94..74ec5de 100644 --- a/services/actions/githubprreviewer.go +++ b/services/actions/githubprreviewer.go @@ -30,7 +30,7 @@ func NewGithubPRReviewer(config map[string]string) *GithubPRReviewer { } } -func (g *GithubPRReviewer) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubPRReviewer) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubprreviewer_test.go b/services/actions/githubprreviewer_test.go index 99b294d..7d8e498 100644 --- a/services/actions/githubprreviewer_test.go +++ b/services/actions/githubprreviewer_test.go @@ -58,7 +58,7 @@ var _ = Describe("GithubPRReviewer", func() { }, } - result, err := reviewer.Run(ctx, params) + result, err := reviewer.Run(ctx, nil, params) Expect(err).NotTo(HaveOccurred()) Expect(result.Result).To(ContainSubstring("reviewed successfully")) }) @@ -70,7 +70,7 @@ var _ = Describe("GithubPRReviewer", func() { "review_action": "COMMENT", } - result, err := reviewer.Run(ctx, params) + result, err := reviewer.Run(ctx, nil, params) Expect(err).To(HaveOccurred()) Expect(result.Result).To(ContainSubstring("not found")) }) @@ -85,7 +85,7 @@ var _ = Describe("GithubPRReviewer", func() { "review_action": "INVALID_ACTION", } - _, err := reviewer.Run(ctx, params) + _, err := reviewer.Run(ctx, nil, params) Expect(err).To(HaveOccurred()) }) }) diff --git a/services/actions/githubrepositorycreateupdatecontent.go b/services/actions/githubrepositorycreateupdatecontent.go index 382d6fe..0e53a67 100644 --- a/services/actions/githubrepositorycreateupdatecontent.go +++ b/services/actions/githubrepositorycreateupdatecontent.go @@ -30,7 +30,7 @@ func NewGithubRepositoryCreateOrUpdateContent(config map[string]string) *GithubR } } -func (g *GithubRepositoryCreateOrUpdateContent) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubRepositoryCreateOrUpdateContent) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Path string `json:"path"` Repository string `json:"repository"` diff --git a/services/actions/githubrepositorygetallcontent.go b/services/actions/githubrepositorygetallcontent.go index ad1ee9c..91eab18 100644 --- a/services/actions/githubrepositorygetallcontent.go +++ b/services/actions/githubrepositorygetallcontent.go @@ -101,7 +101,7 @@ func (g *GithubRepositoryGetAllContent) getContentRecursively(ctx context.Contex return result.String(), nil } -func (g *GithubRepositoryGetAllContent) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubRepositoryGetAllContent) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubrepositorygetallcontent_test.go b/services/actions/githubrepositorygetallcontent_test.go index 61effd4..1a700db 100644 --- a/services/actions/githubrepositorygetallcontent_test.go +++ b/services/actions/githubrepositorygetallcontent_test.go @@ -45,7 +45,7 @@ var _ = Describe("GithubRepositoryGetAllContent", func() { "path": ".", } - result, err := action.Run(ctx, params) + result, err := action.Run(ctx, nil, params) Expect(err).NotTo(HaveOccurred()) Expect(result.Result).NotTo(BeEmpty()) @@ -64,7 +64,7 @@ var _ = Describe("GithubRepositoryGetAllContent", func() { "path": "non-existent-path", } - _, err := action.Run(ctx, params) + _, err := action.Run(ctx, nil, params) Expect(err).To(HaveOccurred()) }) }) diff --git a/services/actions/githubrepositorygetcontent.go b/services/actions/githubrepositorygetcontent.go index d4cb36f..63d529a 100644 --- a/services/actions/githubrepositorygetcontent.go +++ b/services/actions/githubrepositorygetcontent.go @@ -27,7 +27,7 @@ func NewGithubRepositoryGetContent(config map[string]string) *GithubRepositoryGe } } -func (g *GithubRepositoryGetContent) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubRepositoryGetContent) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Path string `json:"path"` Repository string `json:"repository"` diff --git a/services/actions/githubrepositorylistfiles.go b/services/actions/githubrepositorylistfiles.go index 65b23f6..e6a7ba9 100644 --- a/services/actions/githubrepositorylistfiles.go +++ b/services/actions/githubrepositorylistfiles.go @@ -55,7 +55,7 @@ func (g *GithubRepositoryListFiles) listFilesRecursively(ctx context.Context, pa return files, nil } -func (g *GithubRepositoryListFiles) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubRepositoryListFiles) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubrepositoryreadme.go b/services/actions/githubrepositoryreadme.go index 5d81f24..4f40bec 100644 --- a/services/actions/githubrepositoryreadme.go +++ b/services/actions/githubrepositoryreadme.go @@ -25,7 +25,7 @@ func NewGithubRepositoryREADME(config map[string]string) *GithubRepositoryREADME } } -func (g *GithubRepositoryREADME) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubRepositoryREADME) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/githubrepositorysearchfiles.go b/services/actions/githubrepositorysearchfiles.go index 576e975..5d9e955 100644 --- a/services/actions/githubrepositorysearchfiles.go +++ b/services/actions/githubrepositorysearchfiles.go @@ -71,7 +71,7 @@ func (g *GithubRepositorySearchFiles) searchFilesRecursively(ctx context.Context return result.String(), nil } -func (g *GithubRepositorySearchFiles) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (g *GithubRepositorySearchFiles) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` diff --git a/services/actions/scrape.go b/services/actions/scrape.go index b0e9e5b..bd59188 100644 --- a/services/actions/scrape.go +++ b/services/actions/scrape.go @@ -16,7 +16,7 @@ func NewScraper(config map[string]string) *ScraperAction { type ScraperAction struct{} -func (a *ScraperAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (a *ScraperAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { URL string `json:"url"` }{} diff --git a/services/actions/search.go b/services/actions/search.go index 48fa47b..42748af 100644 --- a/services/actions/search.go +++ b/services/actions/search.go @@ -35,7 +35,7 @@ func NewSearch(config map[string]string) *SearchAction { type SearchAction struct{ results int } -func (a *SearchAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (a *SearchAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Query string `json:"query"` }{} diff --git a/services/actions/sendmail.go b/services/actions/sendmail.go index 1247015..314f004 100644 --- a/services/actions/sendmail.go +++ b/services/actions/sendmail.go @@ -28,7 +28,7 @@ type SendMailAction struct { smtpPort string } -func (a *SendMailAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (a *SendMailAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Message string `json:"message"` To string `json:"to"` diff --git a/services/actions/sendtelegrammessage.go b/services/actions/sendtelegrammessage.go index 2d35106..abc6837 100644 --- a/services/actions/sendtelegrammessage.go +++ b/services/actions/sendtelegrammessage.go @@ -10,6 +10,7 @@ import ( "github.com/mudler/LocalAGI/core/types" "github.com/mudler/LocalAGI/pkg/config" "github.com/mudler/LocalAGI/pkg/xstrings" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/jsonschema" ) @@ -19,9 +20,11 @@ const ( ) type SendTelegramMessageRunner struct { - token string - chatID int64 - bot *bot.Bot + token string + chatID int64 + bot *bot.Bot + customName string + customDescription string } func NewSendTelegramMessageRunner(config map[string]string) *SendTelegramMessageRunner { @@ -46,9 +49,11 @@ func NewSendTelegramMessageRunner(config map[string]string) *SendTelegramMessage } return &SendTelegramMessageRunner{ - token: token, - chatID: chatID, - bot: b, + token: token, + chatID: chatID, + bot: b, + customName: config["custom_name"], + customDescription: config["custom_description"], } } @@ -57,7 +62,7 @@ type TelegramMessageParams struct { Message string `json:"message"` } -func (s *SendTelegramMessageRunner) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (s *SendTelegramMessageRunner) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { var messageParams TelegramMessageParams err := params.Unmarshal(&messageParams) if err != nil { @@ -95,6 +100,11 @@ func (s *SendTelegramMessageRunner) Run(ctx context.Context, params types.Action } } + sharedState.ConversationTracker.AddMessage(fmt.Sprintf("telegram:%d", messageParams.ChatID), openai.ChatCompletionMessage{ + Content: messageParams.Message, + Role: "assistant", + }) + return types.ActionResult{ Result: fmt.Sprintf("Message sent successfully to chat ID %d in %d parts", messageParams.ChatID, len(messages)), Metadata: map[string]interface{}{ @@ -104,10 +114,21 @@ func (s *SendTelegramMessageRunner) Run(ctx context.Context, params types.Action } func (s *SendTelegramMessageRunner) Definition() types.ActionDefinition { + + customName := "send_telegram_message" + if s.customName != "" { + customName = s.customName + } + + customDescription := "Send a message to a Telegram user or group" + if s.customDescription != "" { + customDescription = s.customDescription + } + if s.chatID != 0 { return types.ActionDefinition{ - Name: "send_telegram_message", - Description: "Send a message to a Telegram user or group", + Name: types.ActionDefinitionName(customName), + Description: customDescription, Properties: map[string]jsonschema.Definition{ "message": { Type: jsonschema.String, @@ -119,8 +140,8 @@ func (s *SendTelegramMessageRunner) Definition() types.ActionDefinition { } return types.ActionDefinition{ - Name: "send_telegram_message", - Description: "Send a message to a Telegram user or group", + Name: types.ActionDefinitionName(customName), + Description: customDescription, Properties: map[string]jsonschema.Definition{ "chat_id": { Type: jsonschema.Number, @@ -156,5 +177,19 @@ func SendTelegramMessageConfigMeta() []config.Field { Required: false, HelpText: "Default Telegram chat ID to send messages to (can be overridden in parameters)", }, + { + Name: "custom_name", + Label: "Custom Name", + Type: config.FieldTypeText, + Required: false, + HelpText: "Custom name for the action (optional, defaults to 'send_telegram_message')", + }, + { + Name: "custom_description", + Label: "Custom Description", + Type: config.FieldTypeText, + Required: false, + HelpText: "Custom description for the action (optional, defaults to 'Send a message to a Telegram user or group')", + }, } } diff --git a/services/actions/shell.go b/services/actions/shell.go index 3b7d227..7f957be 100644 --- a/services/actions/shell.go +++ b/services/actions/shell.go @@ -28,7 +28,7 @@ type ShellAction struct { customDescription string } -func (a *ShellAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (a *ShellAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Command string `json:"command"` Host string `json:"host"` diff --git a/services/actions/twitter_post.go b/services/actions/twitter_post.go index eeae064..9469053 100644 --- a/services/actions/twitter_post.go +++ b/services/actions/twitter_post.go @@ -22,7 +22,7 @@ type PostTweetAction struct { noCharacterLimit bool } -func (a *PostTweetAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (a *PostTweetAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Text string `json:"text"` }{} diff --git a/services/actions/wikipedia.go b/services/actions/wikipedia.go index 3f6b683..f4d3709 100644 --- a/services/actions/wikipedia.go +++ b/services/actions/wikipedia.go @@ -15,7 +15,7 @@ func NewWikipedia(config map[string]string) *WikipediaAction { type WikipediaAction struct{} -func (a *WikipediaAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { +func (a *WikipediaAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) { result := struct { Query string `json:"query"` }{} diff --git a/services/connectors/discord.go b/services/connectors/discord.go index 4ebf265..fd11945 100644 --- a/services/connectors/discord.go +++ b/services/connectors/discord.go @@ -2,8 +2,8 @@ package connectors import ( "encoding/json" + "fmt" "strings" - "time" "github.com/bwmarrin/discordgo" "github.com/mudler/LocalAGI/core/agent" @@ -14,9 +14,8 @@ import ( ) type Discord struct { - token string - defaultChannel string - conversationTracker *ConversationTracker[string] + token string + defaultChannel string } // NewDiscord creates a new Discord connector @@ -25,11 +24,6 @@ type Discord struct { // - defaultChannel: Discord channel to always answer even if not mentioned func NewDiscord(config map[string]string) *Discord { - duration, err := time.ParseDuration(config["lastMessageDuration"]) - if err != nil { - duration = 5 * time.Minute - } - token := config["token"] if !strings.HasPrefix(token, "Bot ") { @@ -37,9 +31,8 @@ func NewDiscord(config map[string]string) *Discord { } return &Discord{ - conversationTracker: NewConversationTracker[string](duration), - token: token, - defaultChannel: config["defaultChannel"], + token: token, + defaultChannel: config["defaultChannel"], } } @@ -157,12 +150,12 @@ func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *d func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) { - d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{ + a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("discord:%s", m.ChannelID), openai.ChatCompletionMessage{ Role: "user", Content: m.Content, }) - conv := d.conversationTracker.GetConversation(m.ChannelID) + conv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("discord:%s", m.ChannelID)) jobResult := a.Ask( types.WithConversationHistory(conv), @@ -173,7 +166,7 @@ func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m * return } - d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{ + a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("discord:%s", m.ChannelID), openai.ChatCompletionMessage{ Role: "assistant", Content: jobResult.Response, }) diff --git a/services/connectors/irc.go b/services/connectors/irc.go index 819c6cc..dd7a8fb 100644 --- a/services/connectors/irc.go +++ b/services/connectors/irc.go @@ -15,28 +15,22 @@ import ( ) type IRC struct { - server string - port string - nickname string - channel string - conn *irc.Connection - alwaysReply bool - conversationTracker *ConversationTracker[string] + server string + port string + nickname string + channel string + conn *irc.Connection + alwaysReply bool } func NewIRC(config map[string]string) *IRC { - duration, err := time.ParseDuration(config["lastMessageDuration"]) - if err != nil { - duration = 5 * time.Minute - } return &IRC{ - server: config["server"], - port: config["port"], - nickname: config["nickname"], - channel: config["channel"], - alwaysReply: config["alwaysReply"] == "true", - conversationTracker: NewConversationTracker[string](duration), + server: config["server"], + port: config["port"], + nickname: config["nickname"], + channel: config["channel"], + alwaysReply: config["alwaysReply"] == "true", } } @@ -115,7 +109,7 @@ func (i *IRC) Start(a *agent.Agent) { cleanedMessage := cleanUpMessage(message, i.nickname) go func() { - conv := i.conversationTracker.GetConversation(channel) + conv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("irc:%s", channel)) conv = append(conv, openai.ChatCompletionMessage{ @@ -125,7 +119,7 @@ func (i *IRC) Start(a *agent.Agent) { ) // Update the conversation history - i.conversationTracker.AddMessage(channel, openai.ChatCompletionMessage{ + a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("irc:%s", channel), openai.ChatCompletionMessage{ Content: cleanedMessage, Role: "user", }) @@ -140,7 +134,7 @@ func (i *IRC) Start(a *agent.Agent) { } // Update the conversation history - i.conversationTracker.AddMessage(channel, openai.ChatCompletionMessage{ + a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("irc:%s", channel), openai.ChatCompletionMessage{ Content: res.Response, Role: "assistant", }) @@ -209,7 +203,7 @@ func (i *IRC) Start(a *agent.Agent) { // Start the IRC client in a goroutine go i.conn.Loop() go func() { - select { + select { case <-a.Context().Done(): i.conn.Quit() return @@ -249,11 +243,5 @@ func IRCConfigMeta() []config.Field { Label: "Always Reply", Type: config.FieldTypeCheckbox, }, - { - Name: "lastMessageDuration", - Label: "Last Message Duration", - Type: config.FieldTypeText, - DefaultValue: "5m", - }, } } diff --git a/services/connectors/matrix.go b/services/connectors/matrix.go index 72acc98..c49dec0 100644 --- a/services/connectors/matrix.go +++ b/services/connectors/matrix.go @@ -31,27 +31,20 @@ type Matrix struct { // Track active jobs for cancellation activeJobs map[string][]*types.Job // map[roomID]bool to track if a room has active processing activeJobsMutex sync.RWMutex - - conversationTracker *ConversationTracker[string] } const matrixThinkingMessage = "🤔 thinking..." func NewMatrix(config map[string]string) *Matrix { - duration, err := time.ParseDuration(config["lastMessageDuration"]) - if err != nil { - duration = 5 * time.Minute - } return &Matrix{ - homeserverURL: config["homeserverURL"], - userID: config["userID"], - accessToken: config["accessToken"], - roomID: config["roomID"], - roomMode: config["roomMode"] == "true", - conversationTracker: NewConversationTracker[string](duration), - placeholders: make(map[string]string), - activeJobs: make(map[string][]*types.Job), + homeserverURL: config["homeserverURL"], + userID: config["userID"], + accessToken: config["accessToken"], + roomID: config["roomID"], + roomMode: config["roomMode"] == "true", + placeholders: make(map[string]string), + activeJobs: make(map[string][]*types.Job), } } @@ -149,7 +142,7 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) { // Cancel any active job for this room before starting a new one m.cancelActiveJobForRoom(evt.RoomID.String()) - currentConv := m.conversationTracker.GetConversation(evt.RoomID.String()) + currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("matrix:%s", evt.RoomID.String())) message := evt.Content.AsMessage().Body @@ -163,8 +156,8 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) { Content: message, }) - m.conversationTracker.AddMessage( - evt.RoomID.String(), currentConv[len(currentConv)-1], + a.SharedState().ConversationTracker.AddMessage( + fmt.Sprintf("matrix:%s", evt.RoomID.String()), currentConv[len(currentConv)-1], ) agentOptions = append(agentOptions, types.WithConversationHistory(currentConv)) @@ -209,8 +202,8 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) { return } - m.conversationTracker.AddMessage( - evt.RoomID.String(), openai.ChatCompletionMessage{ + a.SharedState().ConversationTracker.AddMessage( + fmt.Sprintf("matrix:%s", evt.RoomID.String()), openai.ChatCompletionMessage{ Role: "assistant", Content: res.Response, }, @@ -307,11 +300,5 @@ func MatrixConfigMeta() []config.Field { Label: "Room Mode", Type: config.FieldTypeCheckbox, }, - { - Name: "lastMessageDuration", - Label: "Last Message Duration", - Type: config.FieldTypeText, - DefaultValue: "5m", - }, } } diff --git a/services/connectors/slack.go b/services/connectors/slack.go index 84184ce..71a2d21 100644 --- a/services/connectors/slack.go +++ b/services/connectors/slack.go @@ -8,7 +8,6 @@ import ( "os" "strings" "sync" - "time" "github.com/mudler/LocalAGI/pkg/config" "github.com/mudler/LocalAGI/pkg/localoperator" @@ -42,27 +41,19 @@ type Slack struct { // Track active jobs for cancellation activeJobs map[string][]*types.Job // map[channelID]bool to track if a channel has active processing activeJobsMutex sync.RWMutex - - conversationTracker *ConversationTracker[string] } const thinkingMessage = ":hourglass: thinking..." func NewSlack(config map[string]string) *Slack { - duration, err := time.ParseDuration(config["lastMessageDuration"]) - if err != nil { - duration = 5 * time.Minute - } - return &Slack{ - appToken: config["appToken"], - botToken: config["botToken"], - channelID: config["channelID"], - channelMode: config["channelMode"] == "true", - conversationTracker: NewConversationTracker[string](duration), - placeholders: make(map[string]string), - activeJobs: make(map[string][]*types.Job), + appToken: config["appToken"], + botToken: config["botToken"], + channelID: config["channelID"], + channelMode: config["channelMode"] == "true", + placeholders: make(map[string]string), + activeJobs: make(map[string][]*types.Job), } } @@ -140,16 +131,6 @@ func cleanUpUsernameFromMessage(message string, b *slack.AuthTestResponse) strin return cleaned } -func extractUserIDsFromMessage(message string) []string { - var userIDs []string - for _, part := range strings.Split(message, " ") { - if strings.HasPrefix(part, "<@") && strings.HasSuffix(part, ">") { - userIDs = append(userIDs, strings.TrimPrefix(strings.TrimSuffix(part, ">"), "<@")) - } - } - return userIDs -} - func replaceUserIDsWithNamesInMessage(api *slack.Client, message string) string { for _, part := range strings.Split(message, " ") { if strings.HasPrefix(part, "<@") && strings.HasSuffix(part, ">") { @@ -279,7 +260,7 @@ func (t *Slack) handleChannelMessage( // Cancel any active job for this channel before starting a new one t.cancelActiveJobForChannel(ev.Channel) - currentConv := t.conversationTracker.GetConversation(t.channelID) + currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("slack:%s", t.channelID)) message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b)) @@ -323,8 +304,8 @@ func (t *Slack) handleChannelMessage( }) } - t.conversationTracker.AddMessage( - t.channelID, currentConv[len(currentConv)-1], + a.SharedState().ConversationTracker.AddMessage( + fmt.Sprintf("slack:%s", t.channelID), currentConv[len(currentConv)-1], ) agentOptions = append(agentOptions, types.WithConversationHistory(currentConv)) @@ -370,14 +351,14 @@ func (t *Slack) handleChannelMessage( return } - t.conversationTracker.AddMessage( - t.channelID, openai.ChatCompletionMessage{ + a.SharedState().ConversationTracker.AddMessage( + fmt.Sprintf("slack:%s", t.channelID), openai.ChatCompletionMessage{ Role: "assistant", Content: res.Response, }, ) - xlog.Debug("After adding message to conversation tracker", "conversation", t.conversationTracker.GetConversation(t.channelID)) + xlog.Debug("After adding message to conversation tracker", "conversation", a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("slack:%s", t.channelID))) //res.Response = githubmarkdownconvertergo.Slack(res.Response) @@ -752,6 +733,13 @@ func (t *Slack) Start(a *agent.Agent) { if err != nil { xlog.Error(fmt.Sprintf("Error posting message: %v", err)) } + a.SharedState().ConversationTracker.AddMessage( + fmt.Sprintf("slack:%s", t.channelID), + openai.ChatCompletionMessage{ + Content: ccm.Content, + Role: "assistant", + }, + ) }) } @@ -835,11 +823,5 @@ func SlackConfigMeta() []config.Field { Label: "Always Reply", Type: config.FieldTypeCheckbox, }, - { - Name: "lastMessageDuration", - Label: "Last Message Duration", - Type: config.FieldTypeText, - DefaultValue: "5m", - }, } } diff --git a/services/connectors/telegram.go b/services/connectors/telegram.go index 4cbddb4..5e4ca22 100644 --- a/services/connectors/telegram.go +++ b/services/connectors/telegram.go @@ -13,7 +13,6 @@ import ( "slices" "strings" "sync" - "time" "github.com/go-telegram/bot" "github.com/go-telegram/bot/models" @@ -35,14 +34,8 @@ type Telegram struct { bot *bot.Bot agent *agent.Agent - currentconversation map[int64][]openai.ChatCompletionMessage - lastMessageTime map[int64]time.Time - lastMessageDuration time.Duration - admins []string - conversationTracker *ConversationTracker[int64] - // To track placeholder messages placeholders map[string]int // map[jobUUID]messageID placeholderMutex sync.RWMutex @@ -50,6 +43,8 @@ type Telegram struct { // Track active jobs for cancellation activeJobs map[int64][]*types.Job // map[chatID]bool to track if a chat has active processing activeJobsMutex sync.RWMutex + + channelID string } // Send any text message to the bot after the bot has been started @@ -219,6 +214,8 @@ func formatResponseWithURLs(response string, urls []string) string { func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent, update *models.Update) { username := update.Message.From.Username + xlog.Debug("Received message from user", "username", username, "chatID", update.Message.Chat.ID, "message", update.Message.Text) + internalError := func(err error, msg *models.Message) { xlog.Error("Error updating final message", "error", err) b.EditMessageText(ctx, &bot.EditMessageTextParams{ @@ -242,14 +239,14 @@ func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent, // Cancel any active job for this chat before starting a new one t.cancelActiveJobForChat(update.Message.Chat.ID) - currentConv := t.conversationTracker.GetConversation(update.Message.From.ID) + currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("telegram:%d", update.Message.From.ID)) currentConv = append(currentConv, openai.ChatCompletionMessage{ Content: update.Message.Text, Role: "user", }) - t.conversationTracker.AddMessage( - update.Message.From.ID, + a.SharedState().ConversationTracker.AddMessage( + fmt.Sprintf("telegram:%d", update.Message.From.ID), openai.ChatCompletionMessage{ Content: update.Message.Text, Role: "user", @@ -328,8 +325,8 @@ func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent, return } - t.conversationTracker.AddMessage( - update.Message.From.ID, + a.SharedState().ConversationTracker.AddMessage( + fmt.Sprintf("telegram:%d", update.Message.From.ID), openai.ChatCompletionMessage{ Content: res.Response, Role: "assistant", @@ -408,11 +405,34 @@ func (t *Telegram) Start(a *agent.Agent) { t.agent = a // go func() { - // for m := range a.ConversationChannel() { + // forc m := range a.ConversationChannel() { // t.handleNewMessage(ctx, b, m) // } // }() + if t.channelID != "" { + // handle new conversations + a.AddSubscriber(func(ccm openai.ChatCompletionMessage) { + xlog.Debug("Subscriber(telegram)", "message", ccm.Content) + _, err := b.SendMessage(ctx, &bot.SendMessageParams{ + ChatID: t.channelID, + Text: ccm.Content, + }) + if err != nil { + xlog.Error("Error sending message", "error", err) + return + } + + t.agent.SharedState().ConversationTracker.AddMessage( + fmt.Sprintf("telegram:%s", t.channelID), + openai.ChatCompletionMessage{ + Content: ccm.Content, + Role: "assistant", + }, + ) + }) + } + b.Start(ctx) } @@ -422,11 +442,6 @@ func NewTelegramConnector(config map[string]string) (*Telegram, error) { return nil, errors.New("token is required") } - duration, err := time.ParseDuration(config["lastMessageDuration"]) - if err != nil { - duration = 5 * time.Minute - } - admins := []string{} if _, ok := config["admins"]; ok { @@ -434,14 +449,11 @@ func NewTelegramConnector(config map[string]string) (*Telegram, error) { } return &Telegram{ - Token: token, - lastMessageDuration: duration, - admins: admins, - currentconversation: map[int64][]openai.ChatCompletionMessage{}, - lastMessageTime: map[int64]time.Time{}, - conversationTracker: NewConversationTracker[int64](duration), - placeholders: make(map[string]int), - activeJobs: make(map[int64][]*types.Job), + Token: token, + admins: admins, + placeholders: make(map[string]int), + activeJobs: make(map[int64][]*types.Job), + channelID: config["channel_id"], }, nil } @@ -461,10 +473,10 @@ func TelegramConfigMeta() []config.Field { HelpText: "Comma-separated list of Telegram usernames that are allowed to interact with the bot", }, { - Name: "lastMessageDuration", - Label: "Last Message Duration", - Type: config.FieldTypeText, - DefaultValue: "5m", + Name: "channel_id", + Label: "Channel ID", + Type: config.FieldTypeText, + HelpText: "Telegram channel ID to send messages to if the agent needs to initiate a conversation", }, } } diff --git a/webui/app.go b/webui/app.go index ff4b29e..42a5223 100644 --- a/webui/app.go +++ b/webui/app.go @@ -11,12 +11,14 @@ import ( "time" "github.com/google/uuid" + "github.com/mudler/LocalAGI/core/conversations" coreTypes "github.com/mudler/LocalAGI/core/types" + internalTypes "github.com/mudler/LocalAGI/core/types" "github.com/mudler/LocalAGI/pkg/llm" "github.com/mudler/LocalAGI/pkg/xlog" "github.com/mudler/LocalAGI/services" - "github.com/mudler/LocalAGI/services/connectors" "github.com/mudler/LocalAGI/webui/types" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/jsonschema" @@ -33,6 +35,7 @@ type ( htmx *htmx.HTMX config *Config *fiber.App + sharedState *internalTypes.AgentSharedState } ) @@ -47,9 +50,10 @@ func NewApp(opts ...Option) *App { }) a := &App{ - htmx: htmx.New(), - config: config, - App: webapp, + htmx: htmx.New(), + config: config, + App: webapp, + sharedState: internalTypes.NewAgentSharedState(5 * time.Minute), } a.registerRoutes(config.Pool, webapp) @@ -443,7 +447,7 @@ func (a *App) GetActionDefinition(pool *state.AgentPool) func(c *fiber.Ctx) erro } } -func (a *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error { +func (app *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { payload := struct { Config map[string]string `json:"config"` @@ -467,7 +471,7 @@ func (a *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error { ctx, cancel := context.WithTimeout(c.Context(), 200*time.Second) defer cancel() - res, err := a.Run(ctx, payload.Params) + res, err := a.Run(ctx, app.sharedState, payload.Params) if err != nil { xlog.Error("Error running action", "error", err) return errorJSONMessage(c, err.Error()) @@ -484,7 +488,7 @@ func (a *App) ListActions() func(c *fiber.Ctx) error { } } -func (a *App) Responses(pool *state.AgentPool, tracker *connectors.ConversationTracker[string]) func(c *fiber.Ctx) error { +func (a *App) Responses(pool *state.AgentPool, tracker *conversations.ConversationTracker[string]) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { var request types.RequestBody if err := c.BodyParser(&request); err != nil { diff --git a/webui/routes.go b/webui/routes.go index e005432..68d2a3e 100644 --- a/webui/routes.go +++ b/webui/routes.go @@ -13,8 +13,8 @@ import ( fiber "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/filesystem" "github.com/gofiber/fiber/v2/middleware/keyauth" + "github.com/mudler/LocalAGI/core/conversations" "github.com/mudler/LocalAGI/core/sse" - "github.com/mudler/LocalAGI/services/connectors" "github.com/mudler/LocalAGI/core/state" "github.com/mudler/LocalAGI/core/types" @@ -138,7 +138,7 @@ func (app *App) registerRoutes(pool *state.AgentPool, webapp *fiber.App) { webapp.Post("/api/chat/:name", app.Chat(pool)) - conversationTracker := connectors.NewConversationTracker[string](app.config.ConversationStoreDuration) + conversationTracker := conversations.NewConversationTracker[string](app.config.ConversationStoreDuration) webapp.Post("/v1/responses", app.Responses(pool, conversationTracker)) @@ -268,7 +268,7 @@ func (app *App) registerRoutes(pool *state.AgentPool, webapp *fiber.App) { } return c.JSON(fiber.Map{ - "Name": name, + "Name": name, "History": agent.Observer().History(), }) })