feat: track conversations inside connectors (#92)
* switch to observer pattern Signed-off-by: mudler <mudler@localai.io> * keep conversation history in telegram and slack * generalize with conversation tracker --------- Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
53c1554d55
commit
d0cfc4c317
@@ -1,17 +1,19 @@
|
||||
package connectors
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/mudler/LocalAgent/core/agent"
|
||||
"github.com/mudler/LocalAgent/core/types"
|
||||
"github.com/mudler/LocalAgent/pkg/xlog"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type Discord struct {
|
||||
token string
|
||||
defaultChannel string
|
||||
token string
|
||||
defaultChannel string
|
||||
conversationTracker *ConversationTracker[string]
|
||||
}
|
||||
|
||||
// NewDiscord creates a new Discord connector
|
||||
@@ -19,9 +21,16 @@ type Discord struct {
|
||||
// - token: Discord token
|
||||
// - defaultChannel: Discord channel to always answer even if not mentioned
|
||||
func NewDiscord(config map[string]string) *Discord {
|
||||
|
||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
||||
if err != nil {
|
||||
duration = 5 * time.Minute
|
||||
}
|
||||
|
||||
return &Discord{
|
||||
token: config["token"],
|
||||
defaultChannel: config["defaultChannel"],
|
||||
conversationTracker: NewConversationTracker[string](duration),
|
||||
token: config["token"],
|
||||
defaultChannel: config["defaultChannel"],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,6 +78,71 @@ func (d *Discord) Start(a *agent.Agent) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
var messages []*discordgo.Message
|
||||
var err error
|
||||
messages, err = s.ChannelMessages(m.ChannelID, 100, "", m.MessageReference.MessageID, "")
|
||||
if err != nil {
|
||||
xlog.Info("error getting messages,", err)
|
||||
return
|
||||
}
|
||||
|
||||
conv := []openai.ChatCompletionMessage{}
|
||||
|
||||
for _, message := range messages {
|
||||
if message.Author.ID == s.State.User.ID {
|
||||
conv = append(conv, openai.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: message.Content,
|
||||
})
|
||||
} else {
|
||||
conv = append(conv, openai.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
jobResult := a.Ask(
|
||||
types.WithConversationHistory(conv),
|
||||
)
|
||||
|
||||
if jobResult.Error != nil {
|
||||
xlog.Info("error asking agent,", jobResult.Error)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = s.ChannelMessageSend(m.ChannelID, jobResult.Response)
|
||||
if err != nil {
|
||||
xlog.Info("error sending message,", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
|
||||
conv := d.conversationTracker.GetConversation(m.ChannelID)
|
||||
|
||||
jobResult := a.Ask(
|
||||
types.WithConversationHistory(conv),
|
||||
)
|
||||
|
||||
if jobResult.Error != nil {
|
||||
xlog.Info("error asking agent,", jobResult.Error)
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.ChannelMessageSend(m.ChannelID, jobResult.Response)
|
||||
if err != nil {
|
||||
xlog.Info("error sending message,", err)
|
||||
}
|
||||
|
||||
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: m.Content,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// This function will be called (due to AddHandler above) every time a new
|
||||
// message is created on any channel that the authenticated bot has access to.
|
||||
func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
@@ -78,40 +152,30 @@ func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *di
|
||||
if m.Author.ID == s.State.User.ID {
|
||||
return
|
||||
}
|
||||
interact := func() {
|
||||
//m := m.ContentWithMentionsReplaced()
|
||||
content := m.Content
|
||||
|
||||
content = strings.ReplaceAll(content, "<@"+s.State.User.ID+"> ", "")
|
||||
xlog.Info("Received message", "content", content)
|
||||
job := a.Ask(
|
||||
types.WithText(
|
||||
content,
|
||||
),
|
||||
)
|
||||
if job.Error != nil {
|
||||
xlog.Info("error asking agent,", job.Error)
|
||||
// Interact if we are mentioned
|
||||
mentioned := false
|
||||
for _, mention := range m.Mentions {
|
||||
if mention.ID == s.State.User.ID {
|
||||
mentioned = true
|
||||
return
|
||||
}
|
||||
|
||||
xlog.Info("Response", "response", job.Response)
|
||||
_, err := s.ChannelMessageSend(m.ChannelID, job.Response)
|
||||
if err != nil {
|
||||
xlog.Info("error sending message,", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Interact if we are mentioned
|
||||
for _, mention := range m.Mentions {
|
||||
if mention.ID == s.State.User.ID {
|
||||
go interact()
|
||||
return
|
||||
}
|
||||
if !mentioned && d.defaultChannel == "" {
|
||||
xlog.Debug("Not mentioned")
|
||||
return
|
||||
}
|
||||
|
||||
// check if the message is in a thread and get all messages in the thread
|
||||
if m.MessageReference != nil {
|
||||
d.handleThreadMessage(a, s, m)
|
||||
return
|
||||
}
|
||||
|
||||
// Or we are in the default channel (if one is set!)
|
||||
if d.defaultChannel != "" && m.ChannelID == d.defaultChannel {
|
||||
go interact()
|
||||
d.handleChannelMessage(a, s, m)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user