@@ -1,6 +1,8 @@
|
|||||||
package connectors
|
package connectors
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bwmarrin/discordgo"
|
"github.com/bwmarrin/discordgo"
|
||||||
@@ -80,6 +82,8 @@ func (d *Discord) Start(a *agent.Agent) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dg.StateEnabled = true
|
||||||
|
|
||||||
// Register the messageCreate func as a callback for MessageCreate events.
|
// Register the messageCreate func as a callback for MessageCreate events.
|
||||||
dg.AddHandler(d.messageCreate(a))
|
dg.AddHandler(d.messageCreate(a))
|
||||||
|
|
||||||
@@ -104,7 +108,8 @@ func (d *Discord) Start(a *agent.Agent) {
|
|||||||
func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
|
func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||||
var messages []*discordgo.Message
|
var messages []*discordgo.Message
|
||||||
var err error
|
var err error
|
||||||
messages, err = s.ChannelMessages(m.ChannelID, 100, "", m.MessageReference.MessageID, "")
|
|
||||||
|
messages, err = s.ChannelMessages(m.ChannelID, 100, "", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Info("error getting messages,", err)
|
xlog.Info("error getting messages,", err)
|
||||||
return
|
return
|
||||||
@@ -112,20 +117,23 @@ func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *d
|
|||||||
|
|
||||||
conv := []openai.ChatCompletionMessage{}
|
conv := []openai.ChatCompletionMessage{}
|
||||||
|
|
||||||
for _, message := range messages {
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
|
message := messages[i]
|
||||||
if message.Author.ID == s.State.User.ID {
|
if message.Author.ID == s.State.User.ID {
|
||||||
conv = append(conv, openai.ChatCompletionMessage{
|
conv = append(conv, openai.ChatCompletionMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: message.Content,
|
Content: removeBotID(s, message.Content),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
conv = append(conv, openai.ChatCompletionMessage{
|
conv = append(conv, openai.ChatCompletionMessage{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: message.Content,
|
Content: removeBotID(s, message.Content),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xlog.Debug("Conversation", "conversation", conv)
|
||||||
|
|
||||||
jobResult := a.Ask(
|
jobResult := a.Ask(
|
||||||
types.WithConversationHistory(conv),
|
types.WithConversationHistory(conv),
|
||||||
)
|
)
|
||||||
@@ -143,13 +151,13 @@ func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *d
|
|||||||
|
|
||||||
func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
|
func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||||
|
|
||||||
conv := d.conversationTracker.GetConversation(m.ChannelID)
|
|
||||||
|
|
||||||
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{
|
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: m.Content,
|
Content: m.Content,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
conv := d.conversationTracker.GetConversation(m.ChannelID)
|
||||||
|
|
||||||
jobResult := a.Ask(
|
jobResult := a.Ask(
|
||||||
types.WithConversationHistory(conv),
|
types.WithConversationHistory(conv),
|
||||||
)
|
)
|
||||||
@@ -164,10 +172,28 @@ func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *
|
|||||||
Content: jobResult.Response,
|
Content: jobResult.Response,
|
||||||
})
|
})
|
||||||
|
|
||||||
_, err := s.ChannelMessageSend(m.ChannelID, jobResult.Response)
|
thread, err := s.MessageThreadStartComplex(m.ChannelID, m.ID, &discordgo.ThreadStart{
|
||||||
|
Name: "Thread for " + m.Author.Username,
|
||||||
|
AutoArchiveDuration: 60,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Info("error sending message,", err)
|
xlog.Error("error creating thread", "err", err.Error())
|
||||||
|
// Thread already exists
|
||||||
|
_, err = s.ChannelMessageSend(m.ChannelID, jobResult.Response)
|
||||||
|
if err != nil {
|
||||||
|
xlog.Error("error sending message to thread", "err", err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_, err = s.ChannelMessageSend(thread.ID, jobResult.Response)
|
||||||
|
if err != nil {
|
||||||
|
xlog.Error("error sending message,", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeBotID(s *discordgo.Session, m string) string {
|
||||||
|
return strings.ReplaceAll(m, "<@"+s.State.User.ID+">", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function will be called (due to AddHandler above) every time a new
|
// This function will be called (due to AddHandler above) every time a new
|
||||||
@@ -180,12 +206,16 @@ func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *di
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.Content = removeBotID(s, m.Content)
|
||||||
|
|
||||||
|
xlog.Debug("Message received", "content", m.Content, "connector", "discord")
|
||||||
|
|
||||||
// Interact if we are mentioned
|
// Interact if we are mentioned
|
||||||
mentioned := false
|
mentioned := false
|
||||||
for _, mention := range m.Mentions {
|
for _, mention := range m.Mentions {
|
||||||
if mention.ID == s.State.User.ID {
|
if mention.ID == s.State.User.ID {
|
||||||
mentioned = true
|
mentioned = true
|
||||||
return
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,15 +224,30 @@ func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *di
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mm, _ := json.Marshal(m)
|
||||||
|
xlog.Debug("Discord message", "message", string(mm))
|
||||||
|
|
||||||
|
isThread := func() bool {
|
||||||
|
// NOTE: this doesn't seem to work,
|
||||||
|
// even if used in https://github.com/bwmarrin/discordgo/blob/5571950c905ff94d898501e5a0d76895fa140069/examples/threads/main.go#L33
|
||||||
|
ch, err := s.State.Channel(m.ChannelID)
|
||||||
|
return !(err != nil || !ch.IsThread())
|
||||||
|
}
|
||||||
|
|
||||||
// check if the message is in a thread and get all messages in the thread
|
// check if the message is in a thread and get all messages in the thread
|
||||||
if m.MessageReference != nil &&
|
if isThread() {
|
||||||
((d.defaultChannel != "" && m.ChannelID == d.defaultChannel) || (mentioned && d.defaultChannel == "")) {
|
xlog.Debug("Thread message")
|
||||||
d.handleThreadMessage(a, s, m)
|
if (d.defaultChannel != "" && m.ChannelID == d.defaultChannel) || (mentioned && d.defaultChannel == "") {
|
||||||
|
xlog.Debug("Thread message")
|
||||||
|
d.handleThreadMessage(a, s, m)
|
||||||
|
}
|
||||||
|
xlog.Info("ignoring thread message")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Or we are in the default channel (if one is set!)
|
// Or we are in the default channel (if one is set!)
|
||||||
if (d.defaultChannel != "" && m.ChannelID == d.defaultChannel) || (mentioned && d.defaultChannel == "") {
|
if (d.defaultChannel != "" && m.ChannelID == d.defaultChannel) || (mentioned && d.defaultChannel == "") {
|
||||||
|
xlog.Debug("Channel message")
|
||||||
d.handleChannelMessage(a, s, m)
|
d.handleChannelMessage(a, s, m)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user