From daa7dcd12a3663ba4deec801e5980aba1443fe9f Mon Sep 17 00:00:00 2001 From: mudler Date: Wed, 2 Apr 2025 19:40:27 +0200 Subject: [PATCH] fix(discord): make it work Signed-off-by: mudler --- services/connectors/discord.go | 69 ++++++++++++++++++++++++++++------ 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/services/connectors/discord.go b/services/connectors/discord.go index 6f24e05..6d006d9 100644 --- a/services/connectors/discord.go +++ b/services/connectors/discord.go @@ -1,6 +1,8 @@ package connectors import ( + "encoding/json" + "strings" "time" "github.com/bwmarrin/discordgo" @@ -80,6 +82,8 @@ func (d *Discord) Start(a *agent.Agent) { return } + dg.StateEnabled = true + // Register the messageCreate func as a callback for MessageCreate events. dg.AddHandler(d.messageCreate(a)) @@ -104,7 +108,8 @@ func (d *Discord) Start(a *agent.Agent) { func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) { var messages []*discordgo.Message var err error - messages, err = s.ChannelMessages(m.ChannelID, 100, "", m.MessageReference.MessageID, "") + + messages, err = s.ChannelMessages(m.ChannelID, 100, "", "", "") if err != nil { xlog.Info("error getting messages,", err) return @@ -112,20 +117,23 @@ func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *d conv := []openai.ChatCompletionMessage{} - for _, message := range messages { + for i := len(messages) - 1; i >= 0; i-- { + message := messages[i] if message.Author.ID == s.State.User.ID { conv = append(conv, openai.ChatCompletionMessage{ Role: "assistant", - Content: message.Content, + Content: removeBotID(s, message.Content), }) } else { conv = append(conv, openai.ChatCompletionMessage{ Role: "user", - Content: message.Content, + Content: removeBotID(s, message.Content), }) } } + xlog.Debug("Conversation", "conversation", conv) + jobResult := a.Ask( types.WithConversationHistory(conv), ) @@ -143,13 +151,13 @@ func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *d func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) { - conv := d.conversationTracker.GetConversation(m.ChannelID) - d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{ Role: "user", Content: m.Content, }) + conv := d.conversationTracker.GetConversation(m.ChannelID) + jobResult := a.Ask( types.WithConversationHistory(conv), ) @@ -164,10 +172,28 @@ func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m * Content: jobResult.Response, }) - _, err := s.ChannelMessageSend(m.ChannelID, jobResult.Response) + thread, err := s.MessageThreadStartComplex(m.ChannelID, m.ID, &discordgo.ThreadStart{ + Name: "Thread for " + m.Author.Username, + AutoArchiveDuration: 60, + }) if err != nil { - xlog.Info("error sending message,", err) + xlog.Error("error creating thread", "err", err.Error()) + // Thread already exists + _, err = s.ChannelMessageSend(m.ChannelID, jobResult.Response) + if err != nil { + xlog.Error("error sending message to thread", "err", err.Error()) + } + } else { + _, err = s.ChannelMessageSend(thread.ID, jobResult.Response) + if err != nil { + xlog.Error("error sending message,", err) + } } + +} + +func removeBotID(s *discordgo.Session, m string) string { + return strings.ReplaceAll(m, "<@"+s.State.User.ID+">", "") } // This function will be called (due to AddHandler above) every time a new @@ -180,12 +206,16 @@ func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *di return } + m.Content = removeBotID(s, m.Content) + + xlog.Debug("Message received", "content", m.Content, "connector", "discord") + // Interact if we are mentioned mentioned := false for _, mention := range m.Mentions { if mention.ID == s.State.User.ID { mentioned = true - return + break } } @@ -194,15 +224,30 @@ func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *di return } + mm, _ := json.Marshal(m) + xlog.Debug("Discord message", "message", string(mm)) + + isThread := func() bool { + // NOTE: this doesn't seem to work, + // even if used in https://github.com/bwmarrin/discordgo/blob/5571950c905ff94d898501e5a0d76895fa140069/examples/threads/main.go#L33 + ch, err := s.State.Channel(m.ChannelID) + return !(err != nil || !ch.IsThread()) + } + // check if the message is in a thread and get all messages in the thread - if m.MessageReference != nil && - ((d.defaultChannel != "" && m.ChannelID == d.defaultChannel) || (mentioned && d.defaultChannel == "")) { - d.handleThreadMessage(a, s, m) + if isThread() { + xlog.Debug("Thread message") + if (d.defaultChannel != "" && m.ChannelID == d.defaultChannel) || (mentioned && d.defaultChannel == "") { + xlog.Debug("Thread message") + d.handleThreadMessage(a, s, m) + } + xlog.Info("ignoring thread message") return } // Or we are in the default channel (if one is set!) if (d.defaultChannel != "" && m.ChannelID == d.defaultChannel) || (mentioned && d.defaultChannel == "") { + xlog.Debug("Channel message") d.handleChannelMessage(a, s, m) return }