feat(agent): shared state, allow to track conversations globally (#148)

* feat(agent): shared state, allow to track conversations globally

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Cleanup

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* track conversations initiated by the bot

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-05-11 22:23:01 +02:00
committed by GitHub
parent 2b07dd79ec
commit c23e655f44
63 changed files with 290 additions and 316 deletions

View File

@@ -81,7 +81,7 @@ func (a *CustomAction) Plannable() bool {
return true return true
} }
func (a *CustomAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *CustomAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
v, err := a.i.Eval(fmt.Sprintf("%s.Run", a.config["name"])) v, err := a.i.Eval(fmt.Sprintf("%s.Run", a.config["name"]))
if err != nil { if err != nil {
return types.ActionResult{}, err return types.ActionResult{}, err

View File

@@ -76,7 +76,7 @@ return []string{"foo"}
Description: "A test action", Description: "A test action",
})) }))
runResult, err := customAction.Run(context.Background(), types.ActionParams{ runResult, err := customAction.Run(context.Background(), nil, types.ActionParams{
"Foo": "bar", "Foo": "bar",
}) })
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View File

@@ -21,7 +21,7 @@ type GoalResponse struct {
Achieved bool `json:"achieved"` Achieved bool `json:"achieved"`
} }
func (a *GoalAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { func (a *GoalAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
return types.ActionResult{}, nil return types.ActionResult{}, nil
} }

View File

@@ -22,7 +22,7 @@ type IntentResponse struct {
Reasoning string `json:"reasoning"` Reasoning string `json:"reasoning"`
} }
func (a *IntentAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { func (a *IntentAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
return types.ActionResult{}, nil return types.ActionResult{}, nil
} }

View File

@@ -19,7 +19,7 @@ type ConversationActionResponse struct {
Message string `json:"message"` Message string `json:"message"`
} }
func (a *ConversationAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { func (a *ConversationAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
return types.ActionResult{}, nil return types.ActionResult{}, nil
} }

View File

@@ -16,7 +16,7 @@ func NewStop() *StopAction {
type StopAction struct{} type StopAction struct{}
func (a *StopAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { func (a *StopAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
return types.ActionResult{}, nil return types.ActionResult{}, nil
} }

View File

@@ -30,7 +30,7 @@ type PlanSubtask struct {
Reasoning string `json:"reasoning"` Reasoning string `json:"reasoning"`
} }
func (a *PlanAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { func (a *PlanAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
return types.ActionResult{}, nil return types.ActionResult{}, nil
} }

View File

@@ -20,7 +20,7 @@ type ReasoningResponse struct {
Reasoning string `json:"reasoning"` Reasoning string `json:"reasoning"`
} }
func (a *ReasoningAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { func (a *ReasoningAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
return types.ActionResult{}, nil return types.ActionResult{}, nil
} }

View File

@@ -22,7 +22,7 @@ type ReplyResponse struct {
Message string `json:"message"` Message string `json:"message"`
} }
func (a *ReplyAction) Run(context.Context, types.ActionParams) (string, error) { func (a *ReplyAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (string, error) {
return "no-op", nil return "no-op", nil
} }

View File

@@ -15,7 +15,7 @@ func NewState() *StateAction {
type StateAction struct{} type StateAction struct{}
func (a *StateAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) { func (a *StateAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
return types.ActionResult{Result: "internal state has been updated"}, nil return types.ActionResult{Result: "internal state has been updated"}, nil
} }

View File

@@ -46,6 +46,8 @@ type Agent struct {
newMessagesSubscribers []func(openai.ChatCompletionMessage) newMessagesSubscribers []func(openai.ChatCompletionMessage)
observer Observer observer Observer
sharedState *types.AgentSharedState
} }
type RAGDB interface { type RAGDB interface {
@@ -78,6 +80,7 @@ func New(opts ...Option) (*Agent, error) {
context: types.NewActionContext(ctx, cancel), context: types.NewActionContext(ctx, cancel),
newConversations: make(chan openai.ChatCompletionMessage), newConversations: make(chan openai.ChatCompletionMessage),
newMessagesSubscribers: options.newConversationsSubscribers, newMessagesSubscribers: options.newConversationsSubscribers,
sharedState: types.NewAgentSharedState(options.lastMessageDuration),
} }
// Initialize observer if provided // Initialize observer if provided
@@ -118,6 +121,10 @@ func New(opts ...Option) (*Agent, error) {
return a, nil return a, nil
} }
func (a *Agent) SharedState() *types.AgentSharedState {
return a.sharedState
}
func (a *Agent) startNewConversationsConsumer() { func (a *Agent) startNewConversationsConsumer() {
go func() { go func() {
for { for {
@@ -294,7 +301,7 @@ func (a *Agent) runAction(job *types.Job, chosenAction types.Action, params type
for _, act := range a.availableActions() { for _, act := range a.availableActions() {
if act.Definition().Name == chosenAction.Definition().Name { if act.Definition().Name == chosenAction.Definition().Name {
res, err := act.Run(job.GetContext(), params) res, err := act.Run(job.GetContext(), a.sharedState, params)
if err != nil { if err != nil {
if obs != nil { if obs != nil {
obs.Completion = &types.Completion{ obs.Completion = &types.Completion{

View File

@@ -44,7 +44,7 @@ func (a *TestAction) Plannable() bool {
return true return true
} }
func (a *TestAction) Run(c context.Context, p types.ActionParams) (types.ActionResult, error) { func (a *TestAction) Run(c context.Context, sharedState *types.AgentSharedState, p types.ActionParams) (types.ActionResult, error) {
for k, r := range a.response { for k, r := range a.response {
if strings.Contains(strings.ToLower(p.String()), strings.ToLower(k)) { if strings.Contains(strings.ToLower(p.String()), strings.ToLower(k)) {
return types.ActionResult{Result: r}, nil return types.ActionResult{Result: r}, nil

View File

@@ -38,7 +38,7 @@ func (a *mcpAction) Plannable() bool {
return true return true
} }
func (m *mcpAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (m *mcpAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
resp, err := m.mcpClient.CallTool(ctx, m.toolName, params) resp, err := m.mcpClient.CallTool(ctx, m.toolName, params)
if err != nil { if err != nil {
xlog.Error("Failed to call tool", "error", err.Error()) xlog.Error("Failed to call tool", "error", err.Error())

View File

@@ -64,6 +64,8 @@ type options struct {
observer Observer observer Observer
parallelJobs int parallelJobs int
lastMessageDuration time.Duration
} }
func (o *options) SeparatedMultimodalModel() bool { func (o *options) SeparatedMultimodalModel() bool {
@@ -151,6 +153,17 @@ func EnableKnowledgeBaseWithResults(results int) Option {
} }
} }
func WithLastMessageDuration(duration string) Option {
return func(o *options) error {
d, err := time.ParseDuration(duration)
if err != nil {
d = types.DefaultLastMessageDuration
}
o.lastMessageDuration = d
return nil
}
}
func WithParallelJobs(jobs int) Option { func WithParallelJobs(jobs int) Option {
return func(o *options) error { return func(o *options) error {
o.parallelJobs = jobs o.parallelJobs = jobs

View File

@@ -14,10 +14,10 @@ import (
// all information that should be displayed to the LLM // all information that should be displayed to the LLM
// in the prompts // in the prompts
type PromptHUD struct { type PromptHUD struct {
Character Character `json:"character"` Character Character `json:"character"`
CurrentState types.AgentInternalState `json:"current_state"` CurrentState types.AgentInternalState `json:"current_state"`
PermanentGoal string `json:"permanent_goal"` PermanentGoal string `json:"permanent_goal"`
ShowCharacter bool `json:"show_character"` ShowCharacter bool `json:"show_character"`
} }
type Character struct { type Character struct {

View File

@@ -0,0 +1,13 @@
package conversations_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestConversations(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Conversations test suite")
}

View File

@@ -1,4 +1,4 @@
package connectors package conversations
import ( import (
"fmt" "fmt"

View File

@@ -1,9 +1,9 @@
package connectors_test package conversations_test
import ( import (
"time" "time"
"github.com/mudler/LocalAGI/services/connectors" "github.com/mudler/LocalAGI/core/conversations"
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
@@ -11,13 +11,13 @@ import (
var _ = Describe("ConversationTracker", func() { var _ = Describe("ConversationTracker", func() {
var ( var (
tracker *connectors.ConversationTracker[string] tracker *conversations.ConversationTracker[string]
duration time.Duration duration time.Duration
) )
BeforeEach(func() { BeforeEach(func() {
duration = 1 * time.Second duration = 1 * time.Second
tracker = connectors.NewConversationTracker[string](duration) tracker = conversations.NewConversationTracker[string](duration)
}) })
It("should initialize with empty conversations", func() { It("should initialize with empty conversations", func() {
@@ -81,8 +81,8 @@ var _ = Describe("ConversationTracker", func() {
}) })
It("should handle different key types", func() { It("should handle different key types", func() {
trackerInt := connectors.NewConversationTracker[int](duration) trackerInt := conversations.NewConversationTracker[int](duration)
trackerInt64 := connectors.NewConversationTracker[int64](duration) trackerInt64 := conversations.NewConversationTracker[int64](duration)
message := openai.ChatCompletionMessage{ message := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser, Role: openai.ChatMessageRoleUser,

View File

@@ -48,12 +48,13 @@ type AgentConfig struct {
Description string `json:"description" form:"description"` Description string `json:"description" form:"description"`
Model string `json:"model" form:"model"` Model string `json:"model" form:"model"`
MultimodalModel string `json:"multimodal_model" form:"multimodal_model"` MultimodalModel string `json:"multimodal_model" form:"multimodal_model"`
APIURL string `json:"api_url" form:"api_url"` APIURL string `json:"api_url" form:"api_url"`
APIKey string `json:"api_key" form:"api_key"` APIKey string `json:"api_key" form:"api_key"`
LocalRAGURL string `json:"local_rag_url" form:"local_rag_url"` LocalRAGURL string `json:"local_rag_url" form:"local_rag_url"`
LocalRAGAPIKey string `json:"local_rag_api_key" form:"local_rag_api_key"` LocalRAGAPIKey string `json:"local_rag_api_key" form:"local_rag_api_key"`
LastMessageDuration string `json:"last_message_duration" form:"last_message_duration"`
Name string `json:"name" form:"name"` Name string `json:"name" form:"name"`
HUD bool `json:"hud" form:"hud"` HUD bool `json:"hud" form:"hud"`
@@ -329,6 +330,14 @@ func NewAgentConfigMeta(
HelpText: "Maximum number of evaluation loops to perform when addressing gaps in responses", HelpText: "Maximum number of evaluation loops to perform when addressing gaps in responses",
Tags: config.Tags{Section: "AdvancedSettings"}, Tags: config.Tags{Section: "AdvancedSettings"},
}, },
{
Name: "last_message_duration",
Label: "Last Message Duration",
Type: "text",
DefaultValue: "5m",
HelpText: "Duration for the last message to be considered in the conversation",
Tags: config.Tags{Section: "AdvancedSettings"},
},
}, },
MCPServers: []config.Field{ MCPServers: []config.Field{
{ {

View File

@@ -462,6 +462,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O
}), }),
WithSystemPrompt(config.SystemPrompt), WithSystemPrompt(config.SystemPrompt),
WithMultimodalModel(multimodalModel), WithMultimodalModel(multimodalModel),
WithLastMessageDuration(config.LastMessageDuration),
WithAgentResultCallback(func(state types.ActionState) { WithAgentResultCallback(func(state types.ActionState) {
a.Lock() a.Lock()
if _, ok := a.agentStatus[name]; !ok { if _, ok := a.agentStatus[name]; !ok {

View File

@@ -88,7 +88,7 @@ func (a ActionDefinition) ToFunctionDefinition() *openai.FunctionDefinition {
// Actions is something the agent can do // Actions is something the agent can do
type Action interface { type Action interface {
Run(ctx context.Context, action ActionParams) (ActionResult, error) Run(ctx context.Context, sharedState *AgentSharedState, action ActionParams) (ActionResult, error)
Definition() ActionDefinition Definition() ActionDefinition
Plannable() bool Plannable() bool
} }

View File

@@ -1,6 +1,11 @@
package types package types
import "fmt" import (
"fmt"
"time"
"github.com/mudler/LocalAGI/core/conversations"
)
// State is the structure // State is the structure
// that is used to keep track of the current state // that is used to keep track of the current state
@@ -20,6 +25,23 @@ type AgentInternalState struct {
Goal string `json:"goal"` Goal string `json:"goal"`
} }
const (
DefaultLastMessageDuration = 5 * time.Minute
)
type AgentSharedState struct {
ConversationTracker *conversations.ConversationTracker[string] `json:"conversation_tracker"`
}
func NewAgentSharedState(lastMessageDuration time.Duration) *AgentSharedState {
if lastMessageDuration == 0 {
lastMessageDuration = DefaultLastMessageDuration
}
return &AgentSharedState{
ConversationTracker: conversations.NewConversationTracker[string](lastMessageDuration),
}
}
const fmtT = `===================== const fmtT = `=====================
NowDoing: %s NowDoing: %s
DoingNext: %s DoingNext: %s

View File

@@ -18,7 +18,7 @@ func NewBrowse(config map[string]string) *BrowseAction {
type BrowseAction struct{} type BrowseAction struct{}
func (a *BrowseAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *BrowseAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
URL string `json:"url"` URL string `json:"url"`
}{} }{}

View File

@@ -45,7 +45,7 @@ func NewBrowserAgentRunner(config map[string]string, defaultURL string) *Browser
} }
} }
func (b *BrowserAgentRunner) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (b *BrowserAgentRunner) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := api.AgentRequest{} result := api.AgentRequest{}
err := params.Unmarshal(&result) err := params.Unmarshal(&result)
if err != nil { if err != nil {

View File

@@ -52,7 +52,7 @@ type CallAgentAction struct {
blacklist []string blacklist []string
} }
func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *CallAgentAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
AgentName string `json:"agent_name"` AgentName string `json:"agent_name"`
Message string `json:"message"` Message string `json:"message"`

View File

@@ -24,7 +24,7 @@ func NewCounter(config map[string]string) *CounterAction {
} }
// Run executes the counter action // Run executes the counter action
func (a *CounterAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *CounterAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
// Parse parameters // Parse parameters
request := struct { request := struct {
Name string `json:"name"` Name string `json:"name"`

View File

@@ -45,7 +45,7 @@ func NewDeepResearchRunner(config map[string]string, defaultURL string) *DeepRes
} }
} }
func (d *DeepResearchRunner) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (d *DeepResearchRunner) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := api.DeepResearchRequest{} result := api.DeepResearchRequest{}
err := params.Unmarshal(&result) err := params.Unmarshal(&result)
if err != nil { if err != nil {

View File

@@ -29,7 +29,7 @@ type GenImageAction struct {
imageModel string imageModel string
} }
func (a *GenImageAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *GenImageAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Size string `json:"size"` Size string `json:"size"`

View File

@@ -42,7 +42,7 @@ var _ = Describe("GenImageAction", func() {
"size": "256x256", "size": "256x256",
} }
url, err := action.Run(ctx, params) url, err := action.Run(ctx, nil, params)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(url).ToNot(BeEmpty()) Expect(url).ToNot(BeEmpty())
}) })
@@ -52,7 +52,7 @@ var _ = Describe("GenImageAction", func() {
"size": "256x256", "size": "256x256",
} }
_, err := action.Run(ctx, params) _, err := action.Run(ctx, nil, params)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
}) })

View File

@@ -26,7 +26,7 @@ func NewGithubIssueCloser(config map[string]string) *GithubIssuesCloser {
} }
} }
func (g *GithubIssuesCloser) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubIssuesCloser) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -27,7 +27,7 @@ func NewGithubIssueCommenter(config map[string]string) *GithubIssuesCommenter {
} }
} }
func (g *GithubIssuesCommenter) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubIssuesCommenter) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -27,7 +27,7 @@ func NewGithubIssueEditor(config map[string]string) *GithubIssueEditor {
} }
} }
func (g *GithubIssueEditor) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubIssueEditor) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -38,7 +38,7 @@ func NewGithubIssueLabeler(config map[string]string) *GithubIssuesLabeler {
} }
} }
func (g *GithubIssuesLabeler) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubIssuesLabeler) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -27,7 +27,7 @@ func NewGithubIssueOpener(config map[string]string) *GithubIssuesOpener {
} }
} }
func (g *GithubIssuesOpener) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubIssuesOpener) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Title string `json:"title"` Title string `json:"title"`
Body string `json:"text"` Body string `json:"text"`

View File

@@ -27,7 +27,7 @@ func NewGithubIssueReader(config map[string]string) *GithubIssuesReader {
} }
} }
func (g *GithubIssuesReader) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubIssuesReader) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -28,7 +28,7 @@ func NewGithubIssueSearch(config map[string]string) *GithubIssueSearch {
} }
} }
func (g *GithubIssueSearch) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubIssueSearch) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Query string `json:"query"` Query string `json:"query"`
Repository string `json:"repository"` Repository string `json:"repository"`

View File

@@ -3,8 +3,6 @@ package actions
import ( import (
"context" "context"
"fmt" "fmt"
"regexp"
"strconv"
"github.com/google/go-github/v69/github" "github.com/google/go-github/v69/github"
"github.com/mudler/LocalAGI/core/types" "github.com/mudler/LocalAGI/core/types"
@@ -17,96 +15,6 @@ type GithubPRCommenter struct {
client *github.Client client *github.Client
} }
var (
patchRegex = regexp.MustCompile(`^@@.*\d [\+\-](\d+),?(\d+)?.+?@@`)
)
type commitFileInfo struct {
FileName string
hunkInfos []*hunkInfo
sha string
}
type hunkInfo struct {
hunkStart int
hunkEnd int
}
func (hi hunkInfo) isLineInHunk(line int) bool {
return line >= hi.hunkStart && line <= hi.hunkEnd
}
func (cfi *commitFileInfo) getHunkInfo(line int) *hunkInfo {
for _, hunkInfo := range cfi.hunkInfos {
if hunkInfo.isLineInHunk(line) {
return hunkInfo
}
}
return nil
}
func (cfi *commitFileInfo) isLineInChange(line int) bool {
return cfi.getHunkInfo(line) != nil
}
func (cfi commitFileInfo) calculatePosition(line int) *int {
hi := cfi.getHunkInfo(line)
if hi == nil {
return nil
}
position := line - hi.hunkStart
return &position
}
func parseHunkPositions(patch, filename string) ([]*hunkInfo, error) {
hunkInfos := make([]*hunkInfo, 0)
if patch != "" {
groups := patchRegex.FindAllStringSubmatch(patch, -1)
if len(groups) < 1 {
return hunkInfos, fmt.Errorf("the patch details for [%s] could not be resolved", filename)
}
for _, patchGroup := range groups {
endPos := 2
if len(patchGroup) > 2 && patchGroup[2] == "" {
endPos = 1
}
hunkStart, err := strconv.Atoi(patchGroup[1])
if err != nil {
hunkStart = -1
}
hunkEnd, err := strconv.Atoi(patchGroup[endPos])
if err != nil {
hunkEnd = -1
}
hunkInfos = append(hunkInfos, &hunkInfo{
hunkStart: hunkStart,
hunkEnd: hunkEnd,
})
}
}
return hunkInfos, nil
}
func getCommitInfo(file *github.CommitFile) (*commitFileInfo, error) {
patch := file.GetPatch()
hunkInfos, err := parseHunkPositions(patch, *file.Filename)
if err != nil {
return nil, err
}
sha := file.GetSHA()
if sha == "" {
return nil, fmt.Errorf("the sha details for [%s] could not be resolved", *file.Filename)
}
return &commitFileInfo{
FileName: *file.Filename,
hunkInfos: hunkInfos,
sha: sha,
}, nil
}
func NewGithubPRCommenter(config map[string]string) *GithubPRCommenter { func NewGithubPRCommenter(config map[string]string) *GithubPRCommenter {
client := github.NewClient(nil).WithAuthToken(config["token"]) client := github.NewClient(nil).WithAuthToken(config["token"])
@@ -119,7 +27,7 @@ func NewGithubPRCommenter(config map[string]string) *GithubPRCommenter {
} }
} }
func (g *GithubPRCommenter) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubPRCommenter) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -148,7 +148,7 @@ func (g *GithubPRCreator) createOrUpdateFile(ctx context.Context, branch string,
return nil return nil
} }
func (g *GithubPRCreator) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubPRCreator) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -54,7 +54,7 @@ var _ = Describe("GithubPRCreator", func() {
}, },
} }
result, err := action.Run(ctx, params) result, err := action.Run(ctx, nil, params)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(result.Result).To(ContainSubstring("pull request #")) Expect(result.Result).To(ContainSubstring("pull request #"))
}) })
@@ -65,7 +65,7 @@ var _ = Describe("GithubPRCreator", func() {
"body": "This is a test pull request", "body": "This is a test pull request",
} }
_, err := action.Run(ctx, params) _, err := action.Run(ctx, nil, params)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
}) })

View File

@@ -34,7 +34,7 @@ func NewGithubPRReader(config map[string]string) *GithubPRReader {
} }
} }
func (g *GithubPRReader) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubPRReader) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -30,7 +30,7 @@ func NewGithubPRReviewer(config map[string]string) *GithubPRReviewer {
} }
} }
func (g *GithubPRReviewer) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubPRReviewer) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -58,7 +58,7 @@ var _ = Describe("GithubPRReviewer", func() {
}, },
} }
result, err := reviewer.Run(ctx, params) result, err := reviewer.Run(ctx, nil, params)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(result.Result).To(ContainSubstring("reviewed successfully")) Expect(result.Result).To(ContainSubstring("reviewed successfully"))
}) })
@@ -70,7 +70,7 @@ var _ = Describe("GithubPRReviewer", func() {
"review_action": "COMMENT", "review_action": "COMMENT",
} }
result, err := reviewer.Run(ctx, params) result, err := reviewer.Run(ctx, nil, params)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(result.Result).To(ContainSubstring("not found")) Expect(result.Result).To(ContainSubstring("not found"))
}) })
@@ -85,7 +85,7 @@ var _ = Describe("GithubPRReviewer", func() {
"review_action": "INVALID_ACTION", "review_action": "INVALID_ACTION",
} }
_, err := reviewer.Run(ctx, params) _, err := reviewer.Run(ctx, nil, params)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
}) })

View File

@@ -30,7 +30,7 @@ func NewGithubRepositoryCreateOrUpdateContent(config map[string]string) *GithubR
} }
} }
func (g *GithubRepositoryCreateOrUpdateContent) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubRepositoryCreateOrUpdateContent) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Path string `json:"path"` Path string `json:"path"`
Repository string `json:"repository"` Repository string `json:"repository"`

View File

@@ -101,7 +101,7 @@ func (g *GithubRepositoryGetAllContent) getContentRecursively(ctx context.Contex
return result.String(), nil return result.String(), nil
} }
func (g *GithubRepositoryGetAllContent) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubRepositoryGetAllContent) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -45,7 +45,7 @@ var _ = Describe("GithubRepositoryGetAllContent", func() {
"path": ".", "path": ".",
} }
result, err := action.Run(ctx, params) result, err := action.Run(ctx, nil, params)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
Expect(result.Result).NotTo(BeEmpty()) Expect(result.Result).NotTo(BeEmpty())
@@ -64,7 +64,7 @@ var _ = Describe("GithubRepositoryGetAllContent", func() {
"path": "non-existent-path", "path": "non-existent-path",
} }
_, err := action.Run(ctx, params) _, err := action.Run(ctx, nil, params)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
}) })

View File

@@ -27,7 +27,7 @@ func NewGithubRepositoryGetContent(config map[string]string) *GithubRepositoryGe
} }
} }
func (g *GithubRepositoryGetContent) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubRepositoryGetContent) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Path string `json:"path"` Path string `json:"path"`
Repository string `json:"repository"` Repository string `json:"repository"`

View File

@@ -55,7 +55,7 @@ func (g *GithubRepositoryListFiles) listFilesRecursively(ctx context.Context, pa
return files, nil return files, nil
} }
func (g *GithubRepositoryListFiles) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubRepositoryListFiles) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -25,7 +25,7 @@ func NewGithubRepositoryREADME(config map[string]string) *GithubRepositoryREADME
} }
} }
func (g *GithubRepositoryREADME) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubRepositoryREADME) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -71,7 +71,7 @@ func (g *GithubRepositorySearchFiles) searchFilesRecursively(ctx context.Context
return result.String(), nil return result.String(), nil
} }
func (g *GithubRepositorySearchFiles) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (g *GithubRepositorySearchFiles) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Repository string `json:"repository"` Repository string `json:"repository"`
Owner string `json:"owner"` Owner string `json:"owner"`

View File

@@ -16,7 +16,7 @@ func NewScraper(config map[string]string) *ScraperAction {
type ScraperAction struct{} type ScraperAction struct{}
func (a *ScraperAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *ScraperAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
URL string `json:"url"` URL string `json:"url"`
}{} }{}

View File

@@ -35,7 +35,7 @@ func NewSearch(config map[string]string) *SearchAction {
type SearchAction struct{ results int } type SearchAction struct{ results int }
func (a *SearchAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *SearchAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Query string `json:"query"` Query string `json:"query"`
}{} }{}

View File

@@ -28,7 +28,7 @@ type SendMailAction struct {
smtpPort string smtpPort string
} }
func (a *SendMailAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *SendMailAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Message string `json:"message"` Message string `json:"message"`
To string `json:"to"` To string `json:"to"`

View File

@@ -10,6 +10,7 @@ import (
"github.com/mudler/LocalAGI/core/types" "github.com/mudler/LocalAGI/core/types"
"github.com/mudler/LocalAGI/pkg/config" "github.com/mudler/LocalAGI/pkg/config"
"github.com/mudler/LocalAGI/pkg/xstrings" "github.com/mudler/LocalAGI/pkg/xstrings"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
) )
@@ -19,9 +20,11 @@ const (
) )
type SendTelegramMessageRunner struct { type SendTelegramMessageRunner struct {
token string token string
chatID int64 chatID int64
bot *bot.Bot bot *bot.Bot
customName string
customDescription string
} }
func NewSendTelegramMessageRunner(config map[string]string) *SendTelegramMessageRunner { func NewSendTelegramMessageRunner(config map[string]string) *SendTelegramMessageRunner {
@@ -46,9 +49,11 @@ func NewSendTelegramMessageRunner(config map[string]string) *SendTelegramMessage
} }
return &SendTelegramMessageRunner{ return &SendTelegramMessageRunner{
token: token, token: token,
chatID: chatID, chatID: chatID,
bot: b, bot: b,
customName: config["custom_name"],
customDescription: config["custom_description"],
} }
} }
@@ -57,7 +62,7 @@ type TelegramMessageParams struct {
Message string `json:"message"` Message string `json:"message"`
} }
func (s *SendTelegramMessageRunner) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (s *SendTelegramMessageRunner) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
var messageParams TelegramMessageParams var messageParams TelegramMessageParams
err := params.Unmarshal(&messageParams) err := params.Unmarshal(&messageParams)
if err != nil { if err != nil {
@@ -95,6 +100,11 @@ func (s *SendTelegramMessageRunner) Run(ctx context.Context, params types.Action
} }
} }
sharedState.ConversationTracker.AddMessage(fmt.Sprintf("telegram:%d", messageParams.ChatID), openai.ChatCompletionMessage{
Content: messageParams.Message,
Role: "assistant",
})
return types.ActionResult{ return types.ActionResult{
Result: fmt.Sprintf("Message sent successfully to chat ID %d in %d parts", messageParams.ChatID, len(messages)), Result: fmt.Sprintf("Message sent successfully to chat ID %d in %d parts", messageParams.ChatID, len(messages)),
Metadata: map[string]interface{}{ Metadata: map[string]interface{}{
@@ -104,10 +114,21 @@ func (s *SendTelegramMessageRunner) Run(ctx context.Context, params types.Action
} }
func (s *SendTelegramMessageRunner) Definition() types.ActionDefinition { func (s *SendTelegramMessageRunner) Definition() types.ActionDefinition {
customName := "send_telegram_message"
if s.customName != "" {
customName = s.customName
}
customDescription := "Send a message to a Telegram user or group"
if s.customDescription != "" {
customDescription = s.customDescription
}
if s.chatID != 0 { if s.chatID != 0 {
return types.ActionDefinition{ return types.ActionDefinition{
Name: "send_telegram_message", Name: types.ActionDefinitionName(customName),
Description: "Send a message to a Telegram user or group", Description: customDescription,
Properties: map[string]jsonschema.Definition{ Properties: map[string]jsonschema.Definition{
"message": { "message": {
Type: jsonschema.String, Type: jsonschema.String,
@@ -119,8 +140,8 @@ func (s *SendTelegramMessageRunner) Definition() types.ActionDefinition {
} }
return types.ActionDefinition{ return types.ActionDefinition{
Name: "send_telegram_message", Name: types.ActionDefinitionName(customName),
Description: "Send a message to a Telegram user or group", Description: customDescription,
Properties: map[string]jsonschema.Definition{ Properties: map[string]jsonschema.Definition{
"chat_id": { "chat_id": {
Type: jsonschema.Number, Type: jsonschema.Number,
@@ -156,5 +177,19 @@ func SendTelegramMessageConfigMeta() []config.Field {
Required: false, Required: false,
HelpText: "Default Telegram chat ID to send messages to (can be overridden in parameters)", HelpText: "Default Telegram chat ID to send messages to (can be overridden in parameters)",
}, },
{
Name: "custom_name",
Label: "Custom Name",
Type: config.FieldTypeText,
Required: false,
HelpText: "Custom name for the action (optional, defaults to 'send_telegram_message')",
},
{
Name: "custom_description",
Label: "Custom Description",
Type: config.FieldTypeText,
Required: false,
HelpText: "Custom description for the action (optional, defaults to 'Send a message to a Telegram user or group')",
},
} }
} }

View File

@@ -28,7 +28,7 @@ type ShellAction struct {
customDescription string customDescription string
} }
func (a *ShellAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *ShellAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Command string `json:"command"` Command string `json:"command"`
Host string `json:"host"` Host string `json:"host"`

View File

@@ -22,7 +22,7 @@ type PostTweetAction struct {
noCharacterLimit bool noCharacterLimit bool
} }
func (a *PostTweetAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *PostTweetAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Text string `json:"text"` Text string `json:"text"`
}{} }{}

View File

@@ -15,7 +15,7 @@ func NewWikipedia(config map[string]string) *WikipediaAction {
type WikipediaAction struct{} type WikipediaAction struct{}
func (a *WikipediaAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *WikipediaAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
result := struct { result := struct {
Query string `json:"query"` Query string `json:"query"`
}{} }{}

View File

@@ -2,8 +2,8 @@ package connectors
import ( import (
"encoding/json" "encoding/json"
"fmt"
"strings" "strings"
"time"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
"github.com/mudler/LocalAGI/core/agent" "github.com/mudler/LocalAGI/core/agent"
@@ -14,9 +14,8 @@ import (
) )
type Discord struct { type Discord struct {
token string token string
defaultChannel string defaultChannel string
conversationTracker *ConversationTracker[string]
} }
// NewDiscord creates a new Discord connector // NewDiscord creates a new Discord connector
@@ -25,11 +24,6 @@ type Discord struct {
// - defaultChannel: Discord channel to always answer even if not mentioned // - defaultChannel: Discord channel to always answer even if not mentioned
func NewDiscord(config map[string]string) *Discord { func NewDiscord(config map[string]string) *Discord {
duration, err := time.ParseDuration(config["lastMessageDuration"])
if err != nil {
duration = 5 * time.Minute
}
token := config["token"] token := config["token"]
if !strings.HasPrefix(token, "Bot ") { if !strings.HasPrefix(token, "Bot ") {
@@ -37,9 +31,8 @@ func NewDiscord(config map[string]string) *Discord {
} }
return &Discord{ return &Discord{
conversationTracker: NewConversationTracker[string](duration), token: token,
token: token, defaultChannel: config["defaultChannel"],
defaultChannel: config["defaultChannel"],
} }
} }
@@ -157,12 +150,12 @@ func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *d
func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) { func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{ a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("discord:%s", m.ChannelID), openai.ChatCompletionMessage{
Role: "user", Role: "user",
Content: m.Content, Content: m.Content,
}) })
conv := d.conversationTracker.GetConversation(m.ChannelID) conv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("discord:%s", m.ChannelID))
jobResult := a.Ask( jobResult := a.Ask(
types.WithConversationHistory(conv), types.WithConversationHistory(conv),
@@ -173,7 +166,7 @@ func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *
return return
} }
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{ a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("discord:%s", m.ChannelID), openai.ChatCompletionMessage{
Role: "assistant", Role: "assistant",
Content: jobResult.Response, Content: jobResult.Response,
}) })

View File

@@ -15,28 +15,22 @@ import (
) )
type IRC struct { type IRC struct {
server string server string
port string port string
nickname string nickname string
channel string channel string
conn *irc.Connection conn *irc.Connection
alwaysReply bool alwaysReply bool
conversationTracker *ConversationTracker[string]
} }
func NewIRC(config map[string]string) *IRC { func NewIRC(config map[string]string) *IRC {
duration, err := time.ParseDuration(config["lastMessageDuration"])
if err != nil {
duration = 5 * time.Minute
}
return &IRC{ return &IRC{
server: config["server"], server: config["server"],
port: config["port"], port: config["port"],
nickname: config["nickname"], nickname: config["nickname"],
channel: config["channel"], channel: config["channel"],
alwaysReply: config["alwaysReply"] == "true", alwaysReply: config["alwaysReply"] == "true",
conversationTracker: NewConversationTracker[string](duration),
} }
} }
@@ -115,7 +109,7 @@ func (i *IRC) Start(a *agent.Agent) {
cleanedMessage := cleanUpMessage(message, i.nickname) cleanedMessage := cleanUpMessage(message, i.nickname)
go func() { go func() {
conv := i.conversationTracker.GetConversation(channel) conv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("irc:%s", channel))
conv = append(conv, conv = append(conv,
openai.ChatCompletionMessage{ openai.ChatCompletionMessage{
@@ -125,7 +119,7 @@ func (i *IRC) Start(a *agent.Agent) {
) )
// Update the conversation history // Update the conversation history
i.conversationTracker.AddMessage(channel, openai.ChatCompletionMessage{ a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("irc:%s", channel), openai.ChatCompletionMessage{
Content: cleanedMessage, Content: cleanedMessage,
Role: "user", Role: "user",
}) })
@@ -140,7 +134,7 @@ func (i *IRC) Start(a *agent.Agent) {
} }
// Update the conversation history // Update the conversation history
i.conversationTracker.AddMessage(channel, openai.ChatCompletionMessage{ a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("irc:%s", channel), openai.ChatCompletionMessage{
Content: res.Response, Content: res.Response,
Role: "assistant", Role: "assistant",
}) })
@@ -209,7 +203,7 @@ func (i *IRC) Start(a *agent.Agent) {
// Start the IRC client in a goroutine // Start the IRC client in a goroutine
go i.conn.Loop() go i.conn.Loop()
go func() { go func() {
select { select {
case <-a.Context().Done(): case <-a.Context().Done():
i.conn.Quit() i.conn.Quit()
return return
@@ -249,11 +243,5 @@ func IRCConfigMeta() []config.Field {
Label: "Always Reply", Label: "Always Reply",
Type: config.FieldTypeCheckbox, Type: config.FieldTypeCheckbox,
}, },
{
Name: "lastMessageDuration",
Label: "Last Message Duration",
Type: config.FieldTypeText,
DefaultValue: "5m",
},
} }
} }

View File

@@ -31,27 +31,20 @@ type Matrix struct {
// Track active jobs for cancellation // Track active jobs for cancellation
activeJobs map[string][]*types.Job // map[roomID]bool to track if a room has active processing activeJobs map[string][]*types.Job // map[roomID]bool to track if a room has active processing
activeJobsMutex sync.RWMutex activeJobsMutex sync.RWMutex
conversationTracker *ConversationTracker[string]
} }
const matrixThinkingMessage = "🤔 thinking..." const matrixThinkingMessage = "🤔 thinking..."
func NewMatrix(config map[string]string) *Matrix { func NewMatrix(config map[string]string) *Matrix {
duration, err := time.ParseDuration(config["lastMessageDuration"])
if err != nil {
duration = 5 * time.Minute
}
return &Matrix{ return &Matrix{
homeserverURL: config["homeserverURL"], homeserverURL: config["homeserverURL"],
userID: config["userID"], userID: config["userID"],
accessToken: config["accessToken"], accessToken: config["accessToken"],
roomID: config["roomID"], roomID: config["roomID"],
roomMode: config["roomMode"] == "true", roomMode: config["roomMode"] == "true",
conversationTracker: NewConversationTracker[string](duration), placeholders: make(map[string]string),
placeholders: make(map[string]string), activeJobs: make(map[string][]*types.Job),
activeJobs: make(map[string][]*types.Job),
} }
} }
@@ -149,7 +142,7 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) {
// Cancel any active job for this room before starting a new one // Cancel any active job for this room before starting a new one
m.cancelActiveJobForRoom(evt.RoomID.String()) m.cancelActiveJobForRoom(evt.RoomID.String())
currentConv := m.conversationTracker.GetConversation(evt.RoomID.String()) currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("matrix:%s", evt.RoomID.String()))
message := evt.Content.AsMessage().Body message := evt.Content.AsMessage().Body
@@ -163,8 +156,8 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) {
Content: message, Content: message,
}) })
m.conversationTracker.AddMessage( a.SharedState().ConversationTracker.AddMessage(
evt.RoomID.String(), currentConv[len(currentConv)-1], fmt.Sprintf("matrix:%s", evt.RoomID.String()), currentConv[len(currentConv)-1],
) )
agentOptions = append(agentOptions, types.WithConversationHistory(currentConv)) agentOptions = append(agentOptions, types.WithConversationHistory(currentConv))
@@ -209,8 +202,8 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) {
return return
} }
m.conversationTracker.AddMessage( a.SharedState().ConversationTracker.AddMessage(
evt.RoomID.String(), openai.ChatCompletionMessage{ fmt.Sprintf("matrix:%s", evt.RoomID.String()), openai.ChatCompletionMessage{
Role: "assistant", Role: "assistant",
Content: res.Response, Content: res.Response,
}, },
@@ -307,11 +300,5 @@ func MatrixConfigMeta() []config.Field {
Label: "Room Mode", Label: "Room Mode",
Type: config.FieldTypeCheckbox, Type: config.FieldTypeCheckbox,
}, },
{
Name: "lastMessageDuration",
Label: "Last Message Duration",
Type: config.FieldTypeText,
DefaultValue: "5m",
},
} }
} }

View File

@@ -8,7 +8,6 @@ import (
"os" "os"
"strings" "strings"
"sync" "sync"
"time"
"github.com/mudler/LocalAGI/pkg/config" "github.com/mudler/LocalAGI/pkg/config"
"github.com/mudler/LocalAGI/pkg/localoperator" "github.com/mudler/LocalAGI/pkg/localoperator"
@@ -42,27 +41,19 @@ type Slack struct {
// Track active jobs for cancellation // Track active jobs for cancellation
activeJobs map[string][]*types.Job // map[channelID]bool to track if a channel has active processing activeJobs map[string][]*types.Job // map[channelID]bool to track if a channel has active processing
activeJobsMutex sync.RWMutex activeJobsMutex sync.RWMutex
conversationTracker *ConversationTracker[string]
} }
const thinkingMessage = ":hourglass: thinking..." const thinkingMessage = ":hourglass: thinking..."
func NewSlack(config map[string]string) *Slack { func NewSlack(config map[string]string) *Slack {
duration, err := time.ParseDuration(config["lastMessageDuration"])
if err != nil {
duration = 5 * time.Minute
}
return &Slack{ return &Slack{
appToken: config["appToken"], appToken: config["appToken"],
botToken: config["botToken"], botToken: config["botToken"],
channelID: config["channelID"], channelID: config["channelID"],
channelMode: config["channelMode"] == "true", channelMode: config["channelMode"] == "true",
conversationTracker: NewConversationTracker[string](duration), placeholders: make(map[string]string),
placeholders: make(map[string]string), activeJobs: make(map[string][]*types.Job),
activeJobs: make(map[string][]*types.Job),
} }
} }
@@ -140,16 +131,6 @@ func cleanUpUsernameFromMessage(message string, b *slack.AuthTestResponse) strin
return cleaned return cleaned
} }
func extractUserIDsFromMessage(message string) []string {
var userIDs []string
for _, part := range strings.Split(message, " ") {
if strings.HasPrefix(part, "<@") && strings.HasSuffix(part, ">") {
userIDs = append(userIDs, strings.TrimPrefix(strings.TrimSuffix(part, ">"), "<@"))
}
}
return userIDs
}
func replaceUserIDsWithNamesInMessage(api *slack.Client, message string) string { func replaceUserIDsWithNamesInMessage(api *slack.Client, message string) string {
for _, part := range strings.Split(message, " ") { for _, part := range strings.Split(message, " ") {
if strings.HasPrefix(part, "<@") && strings.HasSuffix(part, ">") { if strings.HasPrefix(part, "<@") && strings.HasSuffix(part, ">") {
@@ -279,7 +260,7 @@ func (t *Slack) handleChannelMessage(
// Cancel any active job for this channel before starting a new one // Cancel any active job for this channel before starting a new one
t.cancelActiveJobForChannel(ev.Channel) t.cancelActiveJobForChannel(ev.Channel)
currentConv := t.conversationTracker.GetConversation(t.channelID) currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("slack:%s", t.channelID))
message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b)) message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b))
@@ -323,8 +304,8 @@ func (t *Slack) handleChannelMessage(
}) })
} }
t.conversationTracker.AddMessage( a.SharedState().ConversationTracker.AddMessage(
t.channelID, currentConv[len(currentConv)-1], fmt.Sprintf("slack:%s", t.channelID), currentConv[len(currentConv)-1],
) )
agentOptions = append(agentOptions, types.WithConversationHistory(currentConv)) agentOptions = append(agentOptions, types.WithConversationHistory(currentConv))
@@ -370,14 +351,14 @@ func (t *Slack) handleChannelMessage(
return return
} }
t.conversationTracker.AddMessage( a.SharedState().ConversationTracker.AddMessage(
t.channelID, openai.ChatCompletionMessage{ fmt.Sprintf("slack:%s", t.channelID), openai.ChatCompletionMessage{
Role: "assistant", Role: "assistant",
Content: res.Response, Content: res.Response,
}, },
) )
xlog.Debug("After adding message to conversation tracker", "conversation", t.conversationTracker.GetConversation(t.channelID)) xlog.Debug("After adding message to conversation tracker", "conversation", a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("slack:%s", t.channelID)))
//res.Response = githubmarkdownconvertergo.Slack(res.Response) //res.Response = githubmarkdownconvertergo.Slack(res.Response)
@@ -752,6 +733,13 @@ func (t *Slack) Start(a *agent.Agent) {
if err != nil { if err != nil {
xlog.Error(fmt.Sprintf("Error posting message: %v", err)) xlog.Error(fmt.Sprintf("Error posting message: %v", err))
} }
a.SharedState().ConversationTracker.AddMessage(
fmt.Sprintf("slack:%s", t.channelID),
openai.ChatCompletionMessage{
Content: ccm.Content,
Role: "assistant",
},
)
}) })
} }
@@ -835,11 +823,5 @@ func SlackConfigMeta() []config.Field {
Label: "Always Reply", Label: "Always Reply",
Type: config.FieldTypeCheckbox, Type: config.FieldTypeCheckbox,
}, },
{
Name: "lastMessageDuration",
Label: "Last Message Duration",
Type: config.FieldTypeText,
DefaultValue: "5m",
},
} }
} }

View File

@@ -13,7 +13,6 @@ import (
"slices" "slices"
"strings" "strings"
"sync" "sync"
"time"
"github.com/go-telegram/bot" "github.com/go-telegram/bot"
"github.com/go-telegram/bot/models" "github.com/go-telegram/bot/models"
@@ -35,14 +34,8 @@ type Telegram struct {
bot *bot.Bot bot *bot.Bot
agent *agent.Agent agent *agent.Agent
currentconversation map[int64][]openai.ChatCompletionMessage
lastMessageTime map[int64]time.Time
lastMessageDuration time.Duration
admins []string admins []string
conversationTracker *ConversationTracker[int64]
// To track placeholder messages // To track placeholder messages
placeholders map[string]int // map[jobUUID]messageID placeholders map[string]int // map[jobUUID]messageID
placeholderMutex sync.RWMutex placeholderMutex sync.RWMutex
@@ -50,6 +43,8 @@ type Telegram struct {
// Track active jobs for cancellation // Track active jobs for cancellation
activeJobs map[int64][]*types.Job // map[chatID]bool to track if a chat has active processing activeJobs map[int64][]*types.Job // map[chatID]bool to track if a chat has active processing
activeJobsMutex sync.RWMutex activeJobsMutex sync.RWMutex
channelID string
} }
// Send any text message to the bot after the bot has been started // Send any text message to the bot after the bot has been started
@@ -219,6 +214,8 @@ func formatResponseWithURLs(response string, urls []string) string {
func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent, update *models.Update) { func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent, update *models.Update) {
username := update.Message.From.Username username := update.Message.From.Username
xlog.Debug("Received message from user", "username", username, "chatID", update.Message.Chat.ID, "message", update.Message.Text)
internalError := func(err error, msg *models.Message) { internalError := func(err error, msg *models.Message) {
xlog.Error("Error updating final message", "error", err) xlog.Error("Error updating final message", "error", err)
b.EditMessageText(ctx, &bot.EditMessageTextParams{ b.EditMessageText(ctx, &bot.EditMessageTextParams{
@@ -242,14 +239,14 @@ func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent,
// Cancel any active job for this chat before starting a new one // Cancel any active job for this chat before starting a new one
t.cancelActiveJobForChat(update.Message.Chat.ID) t.cancelActiveJobForChat(update.Message.Chat.ID)
currentConv := t.conversationTracker.GetConversation(update.Message.From.ID) currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("telegram:%d", update.Message.From.ID))
currentConv = append(currentConv, openai.ChatCompletionMessage{ currentConv = append(currentConv, openai.ChatCompletionMessage{
Content: update.Message.Text, Content: update.Message.Text,
Role: "user", Role: "user",
}) })
t.conversationTracker.AddMessage( a.SharedState().ConversationTracker.AddMessage(
update.Message.From.ID, fmt.Sprintf("telegram:%d", update.Message.From.ID),
openai.ChatCompletionMessage{ openai.ChatCompletionMessage{
Content: update.Message.Text, Content: update.Message.Text,
Role: "user", Role: "user",
@@ -328,8 +325,8 @@ func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent,
return return
} }
t.conversationTracker.AddMessage( a.SharedState().ConversationTracker.AddMessage(
update.Message.From.ID, fmt.Sprintf("telegram:%d", update.Message.From.ID),
openai.ChatCompletionMessage{ openai.ChatCompletionMessage{
Content: res.Response, Content: res.Response,
Role: "assistant", Role: "assistant",
@@ -408,11 +405,34 @@ func (t *Telegram) Start(a *agent.Agent) {
t.agent = a t.agent = a
// go func() { // go func() {
// for m := range a.ConversationChannel() { // forc m := range a.ConversationChannel() {
// t.handleNewMessage(ctx, b, m) // t.handleNewMessage(ctx, b, m)
// } // }
// }() // }()
if t.channelID != "" {
// handle new conversations
a.AddSubscriber(func(ccm openai.ChatCompletionMessage) {
xlog.Debug("Subscriber(telegram)", "message", ccm.Content)
_, err := b.SendMessage(ctx, &bot.SendMessageParams{
ChatID: t.channelID,
Text: ccm.Content,
})
if err != nil {
xlog.Error("Error sending message", "error", err)
return
}
t.agent.SharedState().ConversationTracker.AddMessage(
fmt.Sprintf("telegram:%s", t.channelID),
openai.ChatCompletionMessage{
Content: ccm.Content,
Role: "assistant",
},
)
})
}
b.Start(ctx) b.Start(ctx)
} }
@@ -422,11 +442,6 @@ func NewTelegramConnector(config map[string]string) (*Telegram, error) {
return nil, errors.New("token is required") return nil, errors.New("token is required")
} }
duration, err := time.ParseDuration(config["lastMessageDuration"])
if err != nil {
duration = 5 * time.Minute
}
admins := []string{} admins := []string{}
if _, ok := config["admins"]; ok { if _, ok := config["admins"]; ok {
@@ -434,14 +449,11 @@ func NewTelegramConnector(config map[string]string) (*Telegram, error) {
} }
return &Telegram{ return &Telegram{
Token: token, Token: token,
lastMessageDuration: duration, admins: admins,
admins: admins, placeholders: make(map[string]int),
currentconversation: map[int64][]openai.ChatCompletionMessage{}, activeJobs: make(map[int64][]*types.Job),
lastMessageTime: map[int64]time.Time{}, channelID: config["channel_id"],
conversationTracker: NewConversationTracker[int64](duration),
placeholders: make(map[string]int),
activeJobs: make(map[int64][]*types.Job),
}, nil }, nil
} }
@@ -461,10 +473,10 @@ func TelegramConfigMeta() []config.Field {
HelpText: "Comma-separated list of Telegram usernames that are allowed to interact with the bot", HelpText: "Comma-separated list of Telegram usernames that are allowed to interact with the bot",
}, },
{ {
Name: "lastMessageDuration", Name: "channel_id",
Label: "Last Message Duration", Label: "Channel ID",
Type: config.FieldTypeText, Type: config.FieldTypeText,
DefaultValue: "5m", HelpText: "Telegram channel ID to send messages to if the agent needs to initiate a conversation",
}, },
} }
} }

View File

@@ -11,12 +11,14 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mudler/LocalAGI/core/conversations"
coreTypes "github.com/mudler/LocalAGI/core/types" coreTypes "github.com/mudler/LocalAGI/core/types"
internalTypes "github.com/mudler/LocalAGI/core/types"
"github.com/mudler/LocalAGI/pkg/llm" "github.com/mudler/LocalAGI/pkg/llm"
"github.com/mudler/LocalAGI/pkg/xlog" "github.com/mudler/LocalAGI/pkg/xlog"
"github.com/mudler/LocalAGI/services" "github.com/mudler/LocalAGI/services"
"github.com/mudler/LocalAGI/services/connectors"
"github.com/mudler/LocalAGI/webui/types" "github.com/mudler/LocalAGI/webui/types"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
@@ -33,6 +35,7 @@ type (
htmx *htmx.HTMX htmx *htmx.HTMX
config *Config config *Config
*fiber.App *fiber.App
sharedState *internalTypes.AgentSharedState
} }
) )
@@ -47,9 +50,10 @@ func NewApp(opts ...Option) *App {
}) })
a := &App{ a := &App{
htmx: htmx.New(), htmx: htmx.New(),
config: config, config: config,
App: webapp, App: webapp,
sharedState: internalTypes.NewAgentSharedState(5 * time.Minute),
} }
a.registerRoutes(config.Pool, webapp) a.registerRoutes(config.Pool, webapp)
@@ -443,7 +447,7 @@ func (a *App) GetActionDefinition(pool *state.AgentPool) func(c *fiber.Ctx) erro
} }
} }
func (a *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error { func (app *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
payload := struct { payload := struct {
Config map[string]string `json:"config"` Config map[string]string `json:"config"`
@@ -467,7 +471,7 @@ func (a *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error {
ctx, cancel := context.WithTimeout(c.Context(), 200*time.Second) ctx, cancel := context.WithTimeout(c.Context(), 200*time.Second)
defer cancel() defer cancel()
res, err := a.Run(ctx, payload.Params) res, err := a.Run(ctx, app.sharedState, payload.Params)
if err != nil { if err != nil {
xlog.Error("Error running action", "error", err) xlog.Error("Error running action", "error", err)
return errorJSONMessage(c, err.Error()) return errorJSONMessage(c, err.Error())
@@ -484,7 +488,7 @@ func (a *App) ListActions() func(c *fiber.Ctx) error {
} }
} }
func (a *App) Responses(pool *state.AgentPool, tracker *connectors.ConversationTracker[string]) func(c *fiber.Ctx) error { func (a *App) Responses(pool *state.AgentPool, tracker *conversations.ConversationTracker[string]) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
var request types.RequestBody var request types.RequestBody
if err := c.BodyParser(&request); err != nil { if err := c.BodyParser(&request); err != nil {

View File

@@ -13,8 +13,8 @@ import (
fiber "github.com/gofiber/fiber/v2" fiber "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/filesystem" "github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/keyauth" "github.com/gofiber/fiber/v2/middleware/keyauth"
"github.com/mudler/LocalAGI/core/conversations"
"github.com/mudler/LocalAGI/core/sse" "github.com/mudler/LocalAGI/core/sse"
"github.com/mudler/LocalAGI/services/connectors"
"github.com/mudler/LocalAGI/core/state" "github.com/mudler/LocalAGI/core/state"
"github.com/mudler/LocalAGI/core/types" "github.com/mudler/LocalAGI/core/types"
@@ -138,7 +138,7 @@ func (app *App) registerRoutes(pool *state.AgentPool, webapp *fiber.App) {
webapp.Post("/api/chat/:name", app.Chat(pool)) webapp.Post("/api/chat/:name", app.Chat(pool))
conversationTracker := connectors.NewConversationTracker[string](app.config.ConversationStoreDuration) conversationTracker := conversations.NewConversationTracker[string](app.config.ConversationStoreDuration)
webapp.Post("/v1/responses", app.Responses(pool, conversationTracker)) webapp.Post("/v1/responses", app.Responses(pool, conversationTracker))
@@ -268,7 +268,7 @@ func (app *App) registerRoutes(pool *state.AgentPool, webapp *fiber.App) {
} }
return c.JSON(fiber.Map{ return c.JSON(fiber.Map{
"Name": name, "Name": name,
"History": agent.Observer().History(), "History": agent.Observer().History(),
}) })
}) })