diff --git a/core/agent/agent.go b/core/agent/agent.go index b877166..57928bf 100644 --- a/core/agent/agent.go +++ b/core/agent/agent.go @@ -40,6 +40,8 @@ type Agent struct { newConversations chan openai.ChatCompletionMessage mcpActions types.Actions + + newMessagesSubscribers []func(openai.ChatCompletionMessage) } type RAGDB interface { @@ -64,12 +66,13 @@ func New(opts ...Option) (*Agent, error) { ctx, cancel := context.WithCancel(c) a := &Agent{ - jobQueue: make(chan *types.Job), - options: options, - client: client, - Character: options.character, - currentState: &action.AgentInternalState{}, - context: types.NewActionContext(ctx, cancel), + jobQueue: make(chan *types.Job), + options: options, + client: client, + Character: options.character, + currentState: &action.AgentInternalState{}, + context: types.NewActionContext(ctx, cancel), + newMessagesSubscribers: options.newConversationsSubscribers, } if a.options.statefile != "" { @@ -102,9 +105,27 @@ func New(opts ...Option) (*Agent, error) { "model", a.options.LLMAPI.Model, ) + a.startNewConversationsConsumer() + return a, nil } +func (a *Agent) startNewConversationsConsumer() { + go func() { + for { + select { + case <-a.context.Done(): + return + + case msg := <-a.newConversations: + for _, s := range a.newMessagesSubscribers { + s(msg) + } + } + } + }() +} + // StopAction stops the current action // if any. Can be called before adding a new job. func (a *Agent) StopAction() { @@ -124,10 +145,6 @@ func (a *Agent) ActionContext() context.Context { return a.actionContext.Context } -func (a *Agent) ConversationChannel() chan openai.ChatCompletionMessage { - return a.newConversations -} - // Ask is a pre-emptive, blocking call that returns the response as soon as it's ready. // It discards any other computation. func (a *Agent) Ask(opts ...types.JobOption) *types.JobResult { diff --git a/core/agent/knowledgebase.go b/core/agent/knowledgebase.go index 56ae999..8d288e5 100644 --- a/core/agent/knowledgebase.go +++ b/core/agent/knowledgebase.go @@ -19,8 +19,7 @@ func (a *Agent) knowledgeBaseLookup(conv Messages) { // Walk conversation from bottom to top, and find the first message of the user // to use it as a query to the KB - var userMessage string - userMessage = conv.GetLatestUserMessage().Content + userMessage := conv.GetLatestUserMessage().Content xlog.Info("[Knowledge Base Lookup] Last user message", "agent", a.Character.Name, "message", userMessage, "lastMessage", conv.GetLatestUserMessage()) diff --git a/core/agent/options.go b/core/agent/options.go index f01f7dd..3376551 100644 --- a/core/agent/options.go +++ b/core/agent/options.go @@ -6,6 +6,7 @@ import ( "time" "github.com/mudler/LocalAgent/core/types" + "github.com/sashabaranov/go-openai" ) type Option func(*options) error @@ -49,6 +50,8 @@ type options struct { conversationsPath string mcpServers []MCPServer + + newConversationsSubscribers []func(openai.ChatCompletionMessage) } func (o *options) SeparatedMultimodalModel() bool { @@ -125,6 +128,13 @@ func EnableKnowledgeBaseWithResults(results int) Option { } } +func WithNewConversationSubscriber(sub func(openai.ChatCompletionMessage)) Option { + return func(o *options) error { + o.newConversationsSubscribers = append(o.newConversationsSubscribers, sub) + return nil + } +} + var EnableInitiateConversations = func(o *options) error { o.initiateConversations = true return nil diff --git a/services/connectors/conversationstracker.go b/services/connectors/conversationstracker.go new file mode 100644 index 0000000..cec34fe --- /dev/null +++ b/services/connectors/conversationstracker.go @@ -0,0 +1,75 @@ +package connectors + +import ( + "fmt" + "sync" + "time" + + "github.com/mudler/LocalAgent/pkg/xlog" + "github.com/sashabaranov/go-openai" +) + +type TrackerKey interface{ ~int | ~int64 | ~string } + +type ConversationTracker[K TrackerKey] struct { + convMutex sync.Mutex + currentconversation map[K][]openai.ChatCompletionMessage + lastMessageTime map[K]time.Time + lastMessageDuration time.Duration +} + +func NewConversationTracker[K TrackerKey](lastMessageDuration time.Duration) *ConversationTracker[K] { + return &ConversationTracker[K]{ + lastMessageDuration: lastMessageDuration, + currentconversation: map[K][]openai.ChatCompletionMessage{}, + lastMessageTime: map[K]time.Time{}, + } +} + +func (c *ConversationTracker[K]) GetConversation(key K) []openai.ChatCompletionMessage { + // Lock the conversation mutex to update the conversation history + c.convMutex.Lock() + defer c.convMutex.Unlock() + + // Clear up the conversation if the last message was sent more than lastMessageDuration ago + currentConv := []openai.ChatCompletionMessage{} + lastMessageTime := c.lastMessageTime[key] + if lastMessageTime.IsZero() { + lastMessageTime = time.Now() + } + if lastMessageTime.Add(c.lastMessageDuration).Before(time.Now()) { + currentConv = []openai.ChatCompletionMessage{} + c.lastMessageTime[key] = time.Now() + xlog.Debug("Conversation history does not exist for", "key", fmt.Sprintf("%v", key)) + } else { + xlog.Debug("Conversation history exists for", "key", fmt.Sprintf("%v", key)) + currentConv = append(currentConv, c.currentconversation[key]...) + } + + // cleanup other conversations if older + for k := range c.currentconversation { + lastMessage, exists := c.lastMessageTime[k] + if !exists { + delete(c.currentconversation, k) + delete(c.lastMessageTime, k) + continue + } + if lastMessage.Add(c.lastMessageDuration).Before(time.Now()) { + xlog.Debug("Cleaning up conversation for", k) + delete(c.currentconversation, k) + delete(c.lastMessageTime, k) + } + } + + return currentConv + +} + +func (c *ConversationTracker[K]) AddMessage(key K, message openai.ChatCompletionMessage) { + // Lock the conversation mutex to update the conversation history + c.convMutex.Lock() + defer c.convMutex.Unlock() + + c.currentconversation[key] = append(c.currentconversation[key], message) + c.lastMessageTime[key] = time.Now() +} diff --git a/services/connectors/discord.go b/services/connectors/discord.go index d264d08..4ed17e5 100644 --- a/services/connectors/discord.go +++ b/services/connectors/discord.go @@ -1,17 +1,19 @@ package connectors import ( - "strings" + "time" "github.com/bwmarrin/discordgo" "github.com/mudler/LocalAgent/core/agent" "github.com/mudler/LocalAgent/core/types" "github.com/mudler/LocalAgent/pkg/xlog" + "github.com/sashabaranov/go-openai" ) type Discord struct { - token string - defaultChannel string + token string + defaultChannel string + conversationTracker *ConversationTracker[string] } // NewDiscord creates a new Discord connector @@ -19,9 +21,16 @@ type Discord struct { // - token: Discord token // - defaultChannel: Discord channel to always answer even if not mentioned func NewDiscord(config map[string]string) *Discord { + + duration, err := time.ParseDuration(config["lastMessageDuration"]) + if err != nil { + duration = 5 * time.Minute + } + return &Discord{ - token: config["token"], - defaultChannel: config["defaultChannel"], + conversationTracker: NewConversationTracker[string](duration), + token: config["token"], + defaultChannel: config["defaultChannel"], } } @@ -69,6 +78,71 @@ func (d *Discord) Start(a *agent.Agent) { }() } +func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) { + var messages []*discordgo.Message + var err error + messages, err = s.ChannelMessages(m.ChannelID, 100, "", m.MessageReference.MessageID, "") + if err != nil { + xlog.Info("error getting messages,", err) + return + } + + conv := []openai.ChatCompletionMessage{} + + for _, message := range messages { + if message.Author.ID == s.State.User.ID { + conv = append(conv, openai.ChatCompletionMessage{ + Role: "assistant", + Content: message.Content, + }) + } else { + conv = append(conv, openai.ChatCompletionMessage{ + Role: "user", + Content: message.Content, + }) + } + } + + jobResult := a.Ask( + types.WithConversationHistory(conv), + ) + + if jobResult.Error != nil { + xlog.Info("error asking agent,", jobResult.Error) + return + } + + _, err = s.ChannelMessageSend(m.ChannelID, jobResult.Response) + if err != nil { + xlog.Info("error sending message,", err) + } +} + +func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) { + + conv := d.conversationTracker.GetConversation(m.ChannelID) + + jobResult := a.Ask( + types.WithConversationHistory(conv), + ) + + if jobResult.Error != nil { + xlog.Info("error asking agent,", jobResult.Error) + return + } + + _, err := s.ChannelMessageSend(m.ChannelID, jobResult.Response) + if err != nil { + xlog.Info("error sending message,", err) + } + + d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{ + Role: "user", + Content: m.Content, + }) + +} + // This function will be called (due to AddHandler above) every time a new // message is created on any channel that the authenticated bot has access to. func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *discordgo.MessageCreate) { @@ -78,40 +152,30 @@ func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *di if m.Author.ID == s.State.User.ID { return } - interact := func() { - //m := m.ContentWithMentionsReplaced() - content := m.Content - content = strings.ReplaceAll(content, "<@"+s.State.User.ID+"> ", "") - xlog.Info("Received message", "content", content) - job := a.Ask( - types.WithText( - content, - ), - ) - if job.Error != nil { - xlog.Info("error asking agent,", job.Error) + // Interact if we are mentioned + mentioned := false + for _, mention := range m.Mentions { + if mention.ID == s.State.User.ID { + mentioned = true return } - - xlog.Info("Response", "response", job.Response) - _, err := s.ChannelMessageSend(m.ChannelID, job.Response) - if err != nil { - xlog.Info("error sending message,", err) - } } - // Interact if we are mentioned - for _, mention := range m.Mentions { - if mention.ID == s.State.User.ID { - go interact() - return - } + if !mentioned && d.defaultChannel == "" { + xlog.Debug("Not mentioned") + return + } + + // check if the message is in a thread and get all messages in the thread + if m.MessageReference != nil { + d.handleThreadMessage(a, s, m) + return } // Or we are in the default channel (if one is set!) if d.defaultChannel != "" && m.ChannelID == d.defaultChannel { - go interact() + d.handleChannelMessage(a, s, m) return } } diff --git a/services/connectors/irc.go b/services/connectors/irc.go index 53a2fd2..026aea0 100644 --- a/services/connectors/irc.go +++ b/services/connectors/irc.go @@ -9,25 +9,33 @@ import ( "github.com/mudler/LocalAgent/core/types" "github.com/mudler/LocalAgent/pkg/xlog" "github.com/mudler/LocalAgent/services/actions" + "github.com/sashabaranov/go-openai" irc "github.com/thoj/go-ircevent" ) type IRC struct { - server string - port string - nickname string - channel string - conn *irc.Connection - alwaysReply bool + server string + port string + nickname string + channel string + conn *irc.Connection + alwaysReply bool + conversationTracker *ConversationTracker[string] } func NewIRC(config map[string]string) *IRC { + + duration, err := time.ParseDuration(config["lastMessageDuration"]) + if err != nil { + duration = 5 * time.Minute + } return &IRC{ - server: config["server"], - port: config["port"], - nickname: config["nickname"], - channel: config["channel"], - alwaysReply: config["alwaysReply"] == "true", + server: config["server"], + port: config["port"], + nickname: config["nickname"], + channel: config["channel"], + alwaysReply: config["alwaysReply"] == "true", + conversationTracker: NewConversationTracker[string](duration), } } @@ -102,13 +110,33 @@ func (i *IRC) Start(a *agent.Agent) { } xlog.Info("Recv message", "message", message, "sender", sender, "channel", channel) - cleanedMessage := "My name is " + sender + ". " + cleanUpMessage(message, i.nickname) + cleanedMessage := cleanUpMessage(message, i.nickname) go func() { - res := a.Ask( - types.WithText(cleanedMessage), + conv := i.conversationTracker.GetConversation(channel) + + conv = append(conv, + openai.ChatCompletionMessage{ + Content: cleanedMessage, + Role: "user", + }, ) + res := a.Ask( + types.WithConversationHistory(conv), + ) + + if res.Response == "" { + xlog.Info("No response from agent") + return + } + + // Update the conversation history + i.conversationTracker.AddMessage(channel, openai.ChatCompletionMessage{ + Content: res.Response, + Role: "assistant", + }) + xlog.Info("Sending message", "message", res.Response, "channel", channel) // Split the response into multiple messages if it's too long diff --git a/services/connectors/slack.go b/services/connectors/slack.go index 9fcd6e6..c635c21 100644 --- a/services/connectors/slack.go +++ b/services/connectors/slack.go @@ -9,6 +9,7 @@ import ( "os" "strings" "sync" + "time" "github.com/mudler/LocalAgent/pkg/xlog" "github.com/mudler/LocalAgent/services/actions" @@ -35,17 +36,28 @@ type Slack struct { placeholders map[string]string // map[jobUUID]messageTS placeholderMutex sync.RWMutex apiClient *slack.Client + + conversationTracker *ConversationTracker[string] + processing sync.Mutex + processingMessage bool } const thinkingMessage = "thinking..." func NewSlack(config map[string]string) *Slack { + + duration, err := time.ParseDuration(config["lastMessageDuration"]) + if err != nil { + duration = 5 * time.Minute + } + return &Slack{ - appToken: config["appToken"], - botToken: config["botToken"], - channelID: config["channelID"], - alwaysReply: config["alwaysReply"] == "true", - placeholders: make(map[string]string), + appToken: config["appToken"], + botToken: config["botToken"], + channelID: config["channelID"], + alwaysReply: config["alwaysReply"] == "true", + conversationTracker: NewConversationTracker[string](duration), + placeholders: make(map[string]string), } } @@ -182,6 +194,26 @@ func (t *Slack) handleChannelMessage( return } + currentConv := t.conversationTracker.GetConversation(t.channelID) + + // Lock the conversation mutex to update the conversation history + t.processing.Lock() + + // If we are already processing something, stop the current action + if t.processingMessage { + a.StopAction() + } else { + t.processingMessage = true + } + t.processing.Unlock() + + // Defer to reset the processing flag + defer func() { + t.processing.Lock() + t.processingMessage = false + t.processing.Unlock() + }() + message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b)) go func() { @@ -221,22 +253,59 @@ func (t *Slack) handleChannelMessage( // If the last message has an image, we send it as a multi content message if len(imageBytes.Bytes()) > 0 { - // // Encode the image to base64 imgBase64, err := encodeImageFromURL(*imageBytes) if err != nil { xlog.Error(fmt.Sprintf("Error encoding image to base64: %v", err)) } else { - agentOptions = append(agentOptions, types.WithTextImage(message, fmt.Sprintf("data:%s;base64,%s", mimeType, imgBase64))) + currentConv = append(currentConv, + openai.ChatCompletionMessage{ + Role: "user", + MultiContent: []openai.ChatMessagePart{ + { + Text: message, + Type: openai.ChatMessagePartTypeText, + }, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: fmt.Sprintf("data:%s;base64,%s", mimeType, imgBase64), + }, + }, + }, + }, + ) } } else { - agentOptions = append(agentOptions, types.WithText(message)) + currentConv = append(currentConv, openai.ChatCompletionMessage{ + Role: "user", + Content: message, + }) } + agentOptions = append(agentOptions, types.WithConversationHistory(currentConv)) + res := a.Ask( agentOptions..., ) + if res.Response == "" { + xlog.Debug(fmt.Sprintf("Empty response from agent")) + return + } + + if res.Error != nil { + xlog.Error(fmt.Sprintf("Error from agent: %v", res.Error)) + return + } + + t.conversationTracker.AddMessage( + t.channelID, openai.ChatCompletionMessage{ + Role: "assistant", + Content: res.Response, + }, + ) + //res.Response = githubmarkdownconvertergo.Slack(res.Response) _, _, err = api.PostMessage(ev.Channel, @@ -250,6 +319,7 @@ func (t *Slack) handleChannelMessage( if err != nil { xlog.Error(fmt.Sprintf("Error posting message: %v", err)) } + }() } @@ -486,6 +556,15 @@ func (t *Slack) handleMention( types.WithMetadata(metadata), ) + if res.Response == "" { + xlog.Debug(fmt.Sprintf("Empty response from agent")) + _, _, err := api.DeleteMessage(ev.Channel, msgTs) + if err != nil { + xlog.Error(fmt.Sprintf("Error deleting message: %v", err)) + } + return + } + // get user id user, err := api.GetUserInfo(ev.User) if err != nil { diff --git a/services/connectors/telegram.go b/services/connectors/telegram.go index ac50616..b0f3e7b 100644 --- a/services/connectors/telegram.go +++ b/services/connectors/telegram.go @@ -5,18 +5,30 @@ import ( "errors" "os" "os/signal" + "slices" + "strings" + "time" "github.com/go-telegram/bot" "github.com/go-telegram/bot/models" "github.com/mudler/LocalAgent/core/agent" "github.com/mudler/LocalAgent/core/types" + "github.com/mudler/LocalAgent/pkg/xlog" + "github.com/sashabaranov/go-openai" ) type Telegram struct { - Token string - lastChatID int64 - bot *bot.Bot - agent *agent.Agent + Token string + bot *bot.Bot + agent *agent.Agent + + currentconversation map[int64][]openai.ChatCompletionMessage + lastMessageTime map[int64]time.Time + lastMessageDuration time.Duration + + admins []string + + conversationTracker *ConversationTracker[int64] } // Send any text message to the bot after the bot has been started @@ -38,24 +50,60 @@ func (t *Telegram) AgentReasoningCallback() func(state types.ActionCurrentState) } } +func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent, update *models.Update) { + username := update.Message.From.Username + + if len(t.admins) > 0 && !slices.Contains(t.admins, username) { + xlog.Info("Unauthorized user", "username", username) + return + } + + currentConv := t.conversationTracker.GetConversation(update.Message.From.ID) + currentConv = append(currentConv, openai.ChatCompletionMessage{ + Content: update.Message.Text, + Role: "user", + }) + + res := a.Ask( + types.WithConversationHistory(currentConv), + ) + + if res.Response == "" { + return + } + + t.conversationTracker.AddMessage( + update.Message.From.ID, + openai.ChatCompletionMessage{ + Content: res.Response, + Role: "assistant", + }, + ) + + b.SendMessage(ctx, &bot.SendMessageParams{ + ParseMode: models.ParseModeMarkdown, + ChatID: update.Message.Chat.ID, + Text: res.Response, + }) +} + +// func (t *Telegram) handleNewMessage(ctx context.Context, b *bot.Bot, m openai.ChatCompletionMessage) { +// if t.lastChatID == 0 { +// return +// } +// b.SendMessage(ctx, &bot.SendMessageParams{ +// ChatID: t.lastChatID, +// Text: m.Content, +// }) +// } + func (t *Telegram) Start(a *agent.Agent) { ctx, cancel := signal.NotifyContext(a.Context(), os.Interrupt) defer cancel() opts := []bot.Option{ bot.WithDefaultHandler(func(ctx context.Context, b *bot.Bot, update *models.Update) { - go func() { - res := a.Ask( - types.WithText( - update.Message.Text, - ), - ) - b.SendMessage(ctx, &bot.SendMessageParams{ - ChatID: update.Message.Chat.ID, - Text: res.Response, - }) - t.lastChatID = update.Message.Chat.ID - }() + go t.handleUpdate(ctx, b, a, update) }), } @@ -67,17 +115,11 @@ func (t *Telegram) Start(a *agent.Agent) { t.bot = b t.agent = a - go func() { - for m := range a.ConversationChannel() { - if t.lastChatID == 0 { - continue - } - b.SendMessage(ctx, &bot.SendMessageParams{ - ChatID: t.lastChatID, - Text: m.Content, - }) - } - }() + // go func() { + // for m := range a.ConversationChannel() { + // t.handleNewMessage(ctx, b, m) + // } + // }() b.Start(ctx) } @@ -88,7 +130,23 @@ func NewTelegramConnector(config map[string]string) (*Telegram, error) { return nil, errors.New("token is required") } + duration, err := time.ParseDuration(config["lastMessageDuration"]) + if err != nil { + duration = 5 * time.Minute + } + + admins := []string{} + + if _, ok := config["admins"]; ok { + admins = append(admins, strings.Split(config["admins"], ",")...) + } + return &Telegram{ - Token: token, + Token: token, + lastMessageDuration: duration, + admins: admins, + currentconversation: map[int64][]openai.ChatCompletionMessage{}, + lastMessageTime: map[int64]time.Time{}, + conversationTracker: NewConversationTracker[int64](duration), }, nil }