fix(discord): make it work

Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
mudler
2025-04-02 19:40:27 +02:00
parent b81f34a8f8
commit daa7dcd12a

View File

@@ -1,6 +1,8 @@
package connectors package connectors
import ( import (
"encoding/json"
"strings"
"time" "time"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
@@ -80,6 +82,8 @@ func (d *Discord) Start(a *agent.Agent) {
return return
} }
dg.StateEnabled = true
// Register the messageCreate func as a callback for MessageCreate events. // Register the messageCreate func as a callback for MessageCreate events.
dg.AddHandler(d.messageCreate(a)) dg.AddHandler(d.messageCreate(a))
@@ -104,7 +108,8 @@ func (d *Discord) Start(a *agent.Agent) {
func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) { func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
var messages []*discordgo.Message var messages []*discordgo.Message
var err error var err error
messages, err = s.ChannelMessages(m.ChannelID, 100, "", m.MessageReference.MessageID, "")
messages, err = s.ChannelMessages(m.ChannelID, 100, "", "", "")
if err != nil { if err != nil {
xlog.Info("error getting messages,", err) xlog.Info("error getting messages,", err)
return return
@@ -112,20 +117,23 @@ func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *d
conv := []openai.ChatCompletionMessage{} conv := []openai.ChatCompletionMessage{}
for _, message := range messages { for i := len(messages) - 1; i >= 0; i-- {
message := messages[i]
if message.Author.ID == s.State.User.ID { if message.Author.ID == s.State.User.ID {
conv = append(conv, openai.ChatCompletionMessage{ conv = append(conv, openai.ChatCompletionMessage{
Role: "assistant", Role: "assistant",
Content: message.Content, Content: removeBotID(s, message.Content),
}) })
} else { } else {
conv = append(conv, openai.ChatCompletionMessage{ conv = append(conv, openai.ChatCompletionMessage{
Role: "user", Role: "user",
Content: message.Content, Content: removeBotID(s, message.Content),
}) })
} }
} }
xlog.Debug("Conversation", "conversation", conv)
jobResult := a.Ask( jobResult := a.Ask(
types.WithConversationHistory(conv), types.WithConversationHistory(conv),
) )
@@ -143,13 +151,13 @@ func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *d
func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) { func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
conv := d.conversationTracker.GetConversation(m.ChannelID)
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{ d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{
Role: "user", Role: "user",
Content: m.Content, Content: m.Content,
}) })
conv := d.conversationTracker.GetConversation(m.ChannelID)
jobResult := a.Ask( jobResult := a.Ask(
types.WithConversationHistory(conv), types.WithConversationHistory(conv),
) )
@@ -164,10 +172,28 @@ func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *
Content: jobResult.Response, Content: jobResult.Response,
}) })
_, err := s.ChannelMessageSend(m.ChannelID, jobResult.Response) thread, err := s.MessageThreadStartComplex(m.ChannelID, m.ID, &discordgo.ThreadStart{
Name: "Thread for " + m.Author.Username,
AutoArchiveDuration: 60,
})
if err != nil { if err != nil {
xlog.Info("error sending message,", err) xlog.Error("error creating thread", "err", err.Error())
// Thread already exists
_, err = s.ChannelMessageSend(m.ChannelID, jobResult.Response)
if err != nil {
xlog.Error("error sending message to thread", "err", err.Error())
} }
} else {
_, err = s.ChannelMessageSend(thread.ID, jobResult.Response)
if err != nil {
xlog.Error("error sending message,", err)
}
}
}
func removeBotID(s *discordgo.Session, m string) string {
return strings.ReplaceAll(m, "<@"+s.State.User.ID+">", "")
} }
// This function will be called (due to AddHandler above) every time a new // This function will be called (due to AddHandler above) every time a new
@@ -180,12 +206,16 @@ func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *di
return return
} }
m.Content = removeBotID(s, m.Content)
xlog.Debug("Message received", "content", m.Content, "connector", "discord")
// Interact if we are mentioned // Interact if we are mentioned
mentioned := false mentioned := false
for _, mention := range m.Mentions { for _, mention := range m.Mentions {
if mention.ID == s.State.User.ID { if mention.ID == s.State.User.ID {
mentioned = true mentioned = true
return break
} }
} }
@@ -194,15 +224,30 @@ func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *di
return return
} }
mm, _ := json.Marshal(m)
xlog.Debug("Discord message", "message", string(mm))
isThread := func() bool {
// NOTE: this doesn't seem to work,
// even if used in https://github.com/bwmarrin/discordgo/blob/5571950c905ff94d898501e5a0d76895fa140069/examples/threads/main.go#L33
ch, err := s.State.Channel(m.ChannelID)
return !(err != nil || !ch.IsThread())
}
// check if the message is in a thread and get all messages in the thread // check if the message is in a thread and get all messages in the thread
if m.MessageReference != nil && if isThread() {
((d.defaultChannel != "" && m.ChannelID == d.defaultChannel) || (mentioned && d.defaultChannel == "")) { xlog.Debug("Thread message")
if (d.defaultChannel != "" && m.ChannelID == d.defaultChannel) || (mentioned && d.defaultChannel == "") {
xlog.Debug("Thread message")
d.handleThreadMessage(a, s, m) d.handleThreadMessage(a, s, m)
}
xlog.Info("ignoring thread message")
return return
} }
// Or we are in the default channel (if one is set!) // Or we are in the default channel (if one is set!)
if (d.defaultChannel != "" && m.ChannelID == d.defaultChannel) || (mentioned && d.defaultChannel == "") { if (d.defaultChannel != "" && m.ChannelID == d.defaultChannel) || (mentioned && d.defaultChannel == "") {
xlog.Debug("Channel message")
d.handleChannelMessage(a, s, m) d.handleChannelMessage(a, s, m)
return return
} }