feat: track conversations inside connectors (#92)

* switch to observer pattern

Signed-off-by: mudler <mudler@localai.io>

* keep conversation history in telegram and slack

* generalize with conversation tracker

---------

Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-03-25 16:31:03 +01:00
committed by GitHub
parent 53c1554d55
commit d0cfc4c317
8 changed files with 422 additions and 92 deletions

View File

@@ -0,0 +1,75 @@
package connectors
import (
"fmt"
"sync"
"time"
"github.com/mudler/LocalAgent/pkg/xlog"
"github.com/sashabaranov/go-openai"
)
type TrackerKey interface{ ~int | ~int64 | ~string }
type ConversationTracker[K TrackerKey] struct {
convMutex sync.Mutex
currentconversation map[K][]openai.ChatCompletionMessage
lastMessageTime map[K]time.Time
lastMessageDuration time.Duration
}
func NewConversationTracker[K TrackerKey](lastMessageDuration time.Duration) *ConversationTracker[K] {
return &ConversationTracker[K]{
lastMessageDuration: lastMessageDuration,
currentconversation: map[K][]openai.ChatCompletionMessage{},
lastMessageTime: map[K]time.Time{},
}
}
func (c *ConversationTracker[K]) GetConversation(key K) []openai.ChatCompletionMessage {
// Lock the conversation mutex to update the conversation history
c.convMutex.Lock()
defer c.convMutex.Unlock()
// Clear up the conversation if the last message was sent more than lastMessageDuration ago
currentConv := []openai.ChatCompletionMessage{}
lastMessageTime := c.lastMessageTime[key]
if lastMessageTime.IsZero() {
lastMessageTime = time.Now()
}
if lastMessageTime.Add(c.lastMessageDuration).Before(time.Now()) {
currentConv = []openai.ChatCompletionMessage{}
c.lastMessageTime[key] = time.Now()
xlog.Debug("Conversation history does not exist for", "key", fmt.Sprintf("%v", key))
} else {
xlog.Debug("Conversation history exists for", "key", fmt.Sprintf("%v", key))
currentConv = append(currentConv, c.currentconversation[key]...)
}
// cleanup other conversations if older
for k := range c.currentconversation {
lastMessage, exists := c.lastMessageTime[k]
if !exists {
delete(c.currentconversation, k)
delete(c.lastMessageTime, k)
continue
}
if lastMessage.Add(c.lastMessageDuration).Before(time.Now()) {
xlog.Debug("Cleaning up conversation for", k)
delete(c.currentconversation, k)
delete(c.lastMessageTime, k)
}
}
return currentConv
}
func (c *ConversationTracker[K]) AddMessage(key K, message openai.ChatCompletionMessage) {
// Lock the conversation mutex to update the conversation history
c.convMutex.Lock()
defer c.convMutex.Unlock()
c.currentconversation[key] = append(c.currentconversation[key], message)
c.lastMessageTime[key] = time.Now()
}

View File

@@ -1,17 +1,19 @@
package connectors
import (
"strings"
"time"
"github.com/bwmarrin/discordgo"
"github.com/mudler/LocalAgent/core/agent"
"github.com/mudler/LocalAgent/core/types"
"github.com/mudler/LocalAgent/pkg/xlog"
"github.com/sashabaranov/go-openai"
)
type Discord struct {
token string
defaultChannel string
token string
defaultChannel string
conversationTracker *ConversationTracker[string]
}
// NewDiscord creates a new Discord connector
@@ -19,9 +21,16 @@ type Discord struct {
// - token: Discord token
// - defaultChannel: Discord channel to always answer even if not mentioned
func NewDiscord(config map[string]string) *Discord {
duration, err := time.ParseDuration(config["lastMessageDuration"])
if err != nil {
duration = 5 * time.Minute
}
return &Discord{
token: config["token"],
defaultChannel: config["defaultChannel"],
conversationTracker: NewConversationTracker[string](duration),
token: config["token"],
defaultChannel: config["defaultChannel"],
}
}
@@ -69,6 +78,71 @@ func (d *Discord) Start(a *agent.Agent) {
}()
}
func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
var messages []*discordgo.Message
var err error
messages, err = s.ChannelMessages(m.ChannelID, 100, "", m.MessageReference.MessageID, "")
if err != nil {
xlog.Info("error getting messages,", err)
return
}
conv := []openai.ChatCompletionMessage{}
for _, message := range messages {
if message.Author.ID == s.State.User.ID {
conv = append(conv, openai.ChatCompletionMessage{
Role: "assistant",
Content: message.Content,
})
} else {
conv = append(conv, openai.ChatCompletionMessage{
Role: "user",
Content: message.Content,
})
}
}
jobResult := a.Ask(
types.WithConversationHistory(conv),
)
if jobResult.Error != nil {
xlog.Info("error asking agent,", jobResult.Error)
return
}
_, err = s.ChannelMessageSend(m.ChannelID, jobResult.Response)
if err != nil {
xlog.Info("error sending message,", err)
}
}
func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
conv := d.conversationTracker.GetConversation(m.ChannelID)
jobResult := a.Ask(
types.WithConversationHistory(conv),
)
if jobResult.Error != nil {
xlog.Info("error asking agent,", jobResult.Error)
return
}
_, err := s.ChannelMessageSend(m.ChannelID, jobResult.Response)
if err != nil {
xlog.Info("error sending message,", err)
}
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{
Role: "user",
Content: m.Content,
})
}
// This function will be called (due to AddHandler above) every time a new
// message is created on any channel that the authenticated bot has access to.
func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *discordgo.MessageCreate) {
@@ -78,40 +152,30 @@ func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *di
if m.Author.ID == s.State.User.ID {
return
}
interact := func() {
//m := m.ContentWithMentionsReplaced()
content := m.Content
content = strings.ReplaceAll(content, "<@"+s.State.User.ID+"> ", "")
xlog.Info("Received message", "content", content)
job := a.Ask(
types.WithText(
content,
),
)
if job.Error != nil {
xlog.Info("error asking agent,", job.Error)
// Interact if we are mentioned
mentioned := false
for _, mention := range m.Mentions {
if mention.ID == s.State.User.ID {
mentioned = true
return
}
xlog.Info("Response", "response", job.Response)
_, err := s.ChannelMessageSend(m.ChannelID, job.Response)
if err != nil {
xlog.Info("error sending message,", err)
}
}
// Interact if we are mentioned
for _, mention := range m.Mentions {
if mention.ID == s.State.User.ID {
go interact()
return
}
if !mentioned && d.defaultChannel == "" {
xlog.Debug("Not mentioned")
return
}
// check if the message is in a thread and get all messages in the thread
if m.MessageReference != nil {
d.handleThreadMessage(a, s, m)
return
}
// Or we are in the default channel (if one is set!)
if d.defaultChannel != "" && m.ChannelID == d.defaultChannel {
go interact()
d.handleChannelMessage(a, s, m)
return
}
}

View File

@@ -9,25 +9,33 @@ import (
"github.com/mudler/LocalAgent/core/types"
"github.com/mudler/LocalAgent/pkg/xlog"
"github.com/mudler/LocalAgent/services/actions"
"github.com/sashabaranov/go-openai"
irc "github.com/thoj/go-ircevent"
)
type IRC struct {
server string
port string
nickname string
channel string
conn *irc.Connection
alwaysReply bool
server string
port string
nickname string
channel string
conn *irc.Connection
alwaysReply bool
conversationTracker *ConversationTracker[string]
}
func NewIRC(config map[string]string) *IRC {
duration, err := time.ParseDuration(config["lastMessageDuration"])
if err != nil {
duration = 5 * time.Minute
}
return &IRC{
server: config["server"],
port: config["port"],
nickname: config["nickname"],
channel: config["channel"],
alwaysReply: config["alwaysReply"] == "true",
server: config["server"],
port: config["port"],
nickname: config["nickname"],
channel: config["channel"],
alwaysReply: config["alwaysReply"] == "true",
conversationTracker: NewConversationTracker[string](duration),
}
}
@@ -102,13 +110,33 @@ func (i *IRC) Start(a *agent.Agent) {
}
xlog.Info("Recv message", "message", message, "sender", sender, "channel", channel)
cleanedMessage := "My name is " + sender + ". " + cleanUpMessage(message, i.nickname)
cleanedMessage := cleanUpMessage(message, i.nickname)
go func() {
res := a.Ask(
types.WithText(cleanedMessage),
conv := i.conversationTracker.GetConversation(channel)
conv = append(conv,
openai.ChatCompletionMessage{
Content: cleanedMessage,
Role: "user",
},
)
res := a.Ask(
types.WithConversationHistory(conv),
)
if res.Response == "" {
xlog.Info("No response from agent")
return
}
// Update the conversation history
i.conversationTracker.AddMessage(channel, openai.ChatCompletionMessage{
Content: res.Response,
Role: "assistant",
})
xlog.Info("Sending message", "message", res.Response, "channel", channel)
// Split the response into multiple messages if it's too long

View File

@@ -9,6 +9,7 @@ import (
"os"
"strings"
"sync"
"time"
"github.com/mudler/LocalAgent/pkg/xlog"
"github.com/mudler/LocalAgent/services/actions"
@@ -35,17 +36,28 @@ type Slack struct {
placeholders map[string]string // map[jobUUID]messageTS
placeholderMutex sync.RWMutex
apiClient *slack.Client
conversationTracker *ConversationTracker[string]
processing sync.Mutex
processingMessage bool
}
const thinkingMessage = "thinking..."
func NewSlack(config map[string]string) *Slack {
duration, err := time.ParseDuration(config["lastMessageDuration"])
if err != nil {
duration = 5 * time.Minute
}
return &Slack{
appToken: config["appToken"],
botToken: config["botToken"],
channelID: config["channelID"],
alwaysReply: config["alwaysReply"] == "true",
placeholders: make(map[string]string),
appToken: config["appToken"],
botToken: config["botToken"],
channelID: config["channelID"],
alwaysReply: config["alwaysReply"] == "true",
conversationTracker: NewConversationTracker[string](duration),
placeholders: make(map[string]string),
}
}
@@ -182,6 +194,26 @@ func (t *Slack) handleChannelMessage(
return
}
currentConv := t.conversationTracker.GetConversation(t.channelID)
// Lock the conversation mutex to update the conversation history
t.processing.Lock()
// If we are already processing something, stop the current action
if t.processingMessage {
a.StopAction()
} else {
t.processingMessage = true
}
t.processing.Unlock()
// Defer to reset the processing flag
defer func() {
t.processing.Lock()
t.processingMessage = false
t.processing.Unlock()
}()
message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b))
go func() {
@@ -221,22 +253,59 @@ func (t *Slack) handleChannelMessage(
// If the last message has an image, we send it as a multi content message
if len(imageBytes.Bytes()) > 0 {
// // Encode the image to base64
imgBase64, err := encodeImageFromURL(*imageBytes)
if err != nil {
xlog.Error(fmt.Sprintf("Error encoding image to base64: %v", err))
} else {
agentOptions = append(agentOptions, types.WithTextImage(message, fmt.Sprintf("data:%s;base64,%s", mimeType, imgBase64)))
currentConv = append(currentConv,
openai.ChatCompletionMessage{
Role: "user",
MultiContent: []openai.ChatMessagePart{
{
Text: message,
Type: openai.ChatMessagePartTypeText,
},
{
Type: openai.ChatMessagePartTypeImageURL,
ImageURL: &openai.ChatMessageImageURL{
URL: fmt.Sprintf("data:%s;base64,%s", mimeType, imgBase64),
},
},
},
},
)
}
} else {
agentOptions = append(agentOptions, types.WithText(message))
currentConv = append(currentConv, openai.ChatCompletionMessage{
Role: "user",
Content: message,
})
}
agentOptions = append(agentOptions, types.WithConversationHistory(currentConv))
res := a.Ask(
agentOptions...,
)
if res.Response == "" {
xlog.Debug(fmt.Sprintf("Empty response from agent"))
return
}
if res.Error != nil {
xlog.Error(fmt.Sprintf("Error from agent: %v", res.Error))
return
}
t.conversationTracker.AddMessage(
t.channelID, openai.ChatCompletionMessage{
Role: "assistant",
Content: res.Response,
},
)
//res.Response = githubmarkdownconvertergo.Slack(res.Response)
_, _, err = api.PostMessage(ev.Channel,
@@ -250,6 +319,7 @@ func (t *Slack) handleChannelMessage(
if err != nil {
xlog.Error(fmt.Sprintf("Error posting message: %v", err))
}
}()
}
@@ -486,6 +556,15 @@ func (t *Slack) handleMention(
types.WithMetadata(metadata),
)
if res.Response == "" {
xlog.Debug(fmt.Sprintf("Empty response from agent"))
_, _, err := api.DeleteMessage(ev.Channel, msgTs)
if err != nil {
xlog.Error(fmt.Sprintf("Error deleting message: %v", err))
}
return
}
// get user id
user, err := api.GetUserInfo(ev.User)
if err != nil {

View File

@@ -5,18 +5,30 @@ import (
"errors"
"os"
"os/signal"
"slices"
"strings"
"time"
"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
"github.com/mudler/LocalAgent/core/agent"
"github.com/mudler/LocalAgent/core/types"
"github.com/mudler/LocalAgent/pkg/xlog"
"github.com/sashabaranov/go-openai"
)
type Telegram struct {
Token string
lastChatID int64
bot *bot.Bot
agent *agent.Agent
Token string
bot *bot.Bot
agent *agent.Agent
currentconversation map[int64][]openai.ChatCompletionMessage
lastMessageTime map[int64]time.Time
lastMessageDuration time.Duration
admins []string
conversationTracker *ConversationTracker[int64]
}
// Send any text message to the bot after the bot has been started
@@ -38,24 +50,60 @@ func (t *Telegram) AgentReasoningCallback() func(state types.ActionCurrentState)
}
}
func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent, update *models.Update) {
username := update.Message.From.Username
if len(t.admins) > 0 && !slices.Contains(t.admins, username) {
xlog.Info("Unauthorized user", "username", username)
return
}
currentConv := t.conversationTracker.GetConversation(update.Message.From.ID)
currentConv = append(currentConv, openai.ChatCompletionMessage{
Content: update.Message.Text,
Role: "user",
})
res := a.Ask(
types.WithConversationHistory(currentConv),
)
if res.Response == "" {
return
}
t.conversationTracker.AddMessage(
update.Message.From.ID,
openai.ChatCompletionMessage{
Content: res.Response,
Role: "assistant",
},
)
b.SendMessage(ctx, &bot.SendMessageParams{
ParseMode: models.ParseModeMarkdown,
ChatID: update.Message.Chat.ID,
Text: res.Response,
})
}
// func (t *Telegram) handleNewMessage(ctx context.Context, b *bot.Bot, m openai.ChatCompletionMessage) {
// if t.lastChatID == 0 {
// return
// }
// b.SendMessage(ctx, &bot.SendMessageParams{
// ChatID: t.lastChatID,
// Text: m.Content,
// })
// }
func (t *Telegram) Start(a *agent.Agent) {
ctx, cancel := signal.NotifyContext(a.Context(), os.Interrupt)
defer cancel()
opts := []bot.Option{
bot.WithDefaultHandler(func(ctx context.Context, b *bot.Bot, update *models.Update) {
go func() {
res := a.Ask(
types.WithText(
update.Message.Text,
),
)
b.SendMessage(ctx, &bot.SendMessageParams{
ChatID: update.Message.Chat.ID,
Text: res.Response,
})
t.lastChatID = update.Message.Chat.ID
}()
go t.handleUpdate(ctx, b, a, update)
}),
}
@@ -67,17 +115,11 @@ func (t *Telegram) Start(a *agent.Agent) {
t.bot = b
t.agent = a
go func() {
for m := range a.ConversationChannel() {
if t.lastChatID == 0 {
continue
}
b.SendMessage(ctx, &bot.SendMessageParams{
ChatID: t.lastChatID,
Text: m.Content,
})
}
}()
// go func() {
// for m := range a.ConversationChannel() {
// t.handleNewMessage(ctx, b, m)
// }
// }()
b.Start(ctx)
}
@@ -88,7 +130,23 @@ func NewTelegramConnector(config map[string]string) (*Telegram, error) {
return nil, errors.New("token is required")
}
duration, err := time.ParseDuration(config["lastMessageDuration"])
if err != nil {
duration = 5 * time.Minute
}
admins := []string{}
if _, ok := config["admins"]; ok {
admins = append(admins, strings.Split(config["admins"], ",")...)
}
return &Telegram{
Token: token,
Token: token,
lastMessageDuration: duration,
admins: admins,
currentconversation: map[int64][]openai.ChatCompletionMessage{},
lastMessageTime: map[int64]time.Time{},
conversationTracker: NewConversationTracker[int64](duration),
}, nil
}