feat(agent): shared state, allow to track conversations globally (#148)
* feat(agent): shared state, allow to track conversations globally Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Cleanup Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * track conversations initiated by the bot Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
2b07dd79ec
commit
c23e655f44
@@ -1,84 +0,0 @@
|
||||
package connectors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAGI/pkg/xlog"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type TrackerKey interface{ ~int | ~int64 | ~string }
|
||||
|
||||
type ConversationTracker[K TrackerKey] struct {
|
||||
convMutex sync.Mutex
|
||||
currentconversation map[K][]openai.ChatCompletionMessage
|
||||
lastMessageTime map[K]time.Time
|
||||
lastMessageDuration time.Duration
|
||||
}
|
||||
|
||||
func NewConversationTracker[K TrackerKey](lastMessageDuration time.Duration) *ConversationTracker[K] {
|
||||
return &ConversationTracker[K]{
|
||||
lastMessageDuration: lastMessageDuration,
|
||||
currentconversation: map[K][]openai.ChatCompletionMessage{},
|
||||
lastMessageTime: map[K]time.Time{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConversationTracker[K]) GetConversation(key K) []openai.ChatCompletionMessage {
|
||||
// Lock the conversation mutex to update the conversation history
|
||||
c.convMutex.Lock()
|
||||
defer c.convMutex.Unlock()
|
||||
|
||||
// Clear up the conversation if the last message was sent more than lastMessageDuration ago
|
||||
currentConv := []openai.ChatCompletionMessage{}
|
||||
lastMessageTime := c.lastMessageTime[key]
|
||||
if lastMessageTime.IsZero() {
|
||||
lastMessageTime = time.Now()
|
||||
}
|
||||
if lastMessageTime.Add(c.lastMessageDuration).Before(time.Now()) {
|
||||
currentConv = []openai.ChatCompletionMessage{}
|
||||
c.lastMessageTime[key] = time.Now()
|
||||
xlog.Debug("Conversation history does not exist for", "key", fmt.Sprintf("%v", key))
|
||||
} else {
|
||||
xlog.Debug("Conversation history exists for", "key", fmt.Sprintf("%v", key))
|
||||
currentConv = append(currentConv, c.currentconversation[key]...)
|
||||
}
|
||||
|
||||
// cleanup other conversations if older
|
||||
for k := range c.currentconversation {
|
||||
lastMessage, exists := c.lastMessageTime[k]
|
||||
if !exists {
|
||||
delete(c.currentconversation, k)
|
||||
delete(c.lastMessageTime, k)
|
||||
continue
|
||||
}
|
||||
if lastMessage.Add(c.lastMessageDuration).Before(time.Now()) {
|
||||
xlog.Debug("Cleaning up conversation for", k)
|
||||
delete(c.currentconversation, k)
|
||||
delete(c.lastMessageTime, k)
|
||||
}
|
||||
}
|
||||
|
||||
return currentConv
|
||||
|
||||
}
|
||||
|
||||
func (c *ConversationTracker[K]) AddMessage(key K, message openai.ChatCompletionMessage) {
|
||||
// Lock the conversation mutex to update the conversation history
|
||||
c.convMutex.Lock()
|
||||
defer c.convMutex.Unlock()
|
||||
|
||||
c.currentconversation[key] = append(c.currentconversation[key], message)
|
||||
c.lastMessageTime[key] = time.Now()
|
||||
}
|
||||
|
||||
func (c *ConversationTracker[K]) SetConversation(key K, messages []openai.ChatCompletionMessage) {
|
||||
// Lock the conversation mutex to update the conversation history
|
||||
c.convMutex.Lock()
|
||||
defer c.convMutex.Unlock()
|
||||
|
||||
c.currentconversation[key] = messages
|
||||
c.lastMessageTime[key] = time.Now()
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
package connectors_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAGI/services/connectors"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
var _ = Describe("ConversationTracker", func() {
|
||||
var (
|
||||
tracker *connectors.ConversationTracker[string]
|
||||
duration time.Duration
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
duration = 1 * time.Second
|
||||
tracker = connectors.NewConversationTracker[string](duration)
|
||||
})
|
||||
|
||||
It("should initialize with empty conversations", func() {
|
||||
Expect(tracker.GetConversation("test")).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should add a message and retrieve it", func() {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello",
|
||||
}
|
||||
tracker.AddMessage("test", message)
|
||||
conv := tracker.GetConversation("test")
|
||||
Expect(conv).To(HaveLen(1))
|
||||
Expect(conv[0]).To(Equal(message))
|
||||
})
|
||||
|
||||
It("should clear the conversation after the duration", func() {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello",
|
||||
}
|
||||
tracker.AddMessage("test", message)
|
||||
time.Sleep(2 * time.Second)
|
||||
conv := tracker.GetConversation("test")
|
||||
Expect(conv).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should keep the conversation within the duration", func() {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello",
|
||||
}
|
||||
tracker.AddMessage("test", message)
|
||||
time.Sleep(500 * time.Millisecond) // Half the duration
|
||||
conv := tracker.GetConversation("test")
|
||||
Expect(conv).To(HaveLen(1))
|
||||
Expect(conv[0]).To(Equal(message))
|
||||
})
|
||||
|
||||
It("should handle multiple keys and clear old conversations", func() {
|
||||
message1 := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello 1",
|
||||
}
|
||||
message2 := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello 2",
|
||||
}
|
||||
|
||||
tracker.AddMessage("key1", message1)
|
||||
tracker.AddMessage("key2", message2)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
conv1 := tracker.GetConversation("key1")
|
||||
conv2 := tracker.GetConversation("key2")
|
||||
|
||||
Expect(conv1).To(BeEmpty())
|
||||
Expect(conv2).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("should handle different key types", func() {
|
||||
trackerInt := connectors.NewConversationTracker[int](duration)
|
||||
trackerInt64 := connectors.NewConversationTracker[int64](duration)
|
||||
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello",
|
||||
}
|
||||
|
||||
trackerInt.AddMessage(1, message)
|
||||
trackerInt64.AddMessage(int64(1), message)
|
||||
|
||||
Expect(trackerInt.GetConversation(1)).To(HaveLen(1))
|
||||
Expect(trackerInt64.GetConversation(int64(1))).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("should cleanup other conversations if older", func() {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: "Hello",
|
||||
}
|
||||
tracker.AddMessage("key1", message)
|
||||
tracker.AddMessage("key2", message)
|
||||
time.Sleep(2 * time.Second)
|
||||
tracker.GetConversation("key3")
|
||||
Expect(tracker.GetConversation("key1")).To(BeEmpty())
|
||||
Expect(tracker.GetConversation("key2")).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
@@ -2,8 +2,8 @@ package connectors
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/mudler/LocalAGI/core/agent"
|
||||
@@ -14,9 +14,8 @@ import (
|
||||
)
|
||||
|
||||
type Discord struct {
|
||||
token string
|
||||
defaultChannel string
|
||||
conversationTracker *ConversationTracker[string]
|
||||
token string
|
||||
defaultChannel string
|
||||
}
|
||||
|
||||
// NewDiscord creates a new Discord connector
|
||||
@@ -25,11 +24,6 @@ type Discord struct {
|
||||
// - defaultChannel: Discord channel to always answer even if not mentioned
|
||||
func NewDiscord(config map[string]string) *Discord {
|
||||
|
||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
||||
if err != nil {
|
||||
duration = 5 * time.Minute
|
||||
}
|
||||
|
||||
token := config["token"]
|
||||
|
||||
if !strings.HasPrefix(token, "Bot ") {
|
||||
@@ -37,9 +31,8 @@ func NewDiscord(config map[string]string) *Discord {
|
||||
}
|
||||
|
||||
return &Discord{
|
||||
conversationTracker: NewConversationTracker[string](duration),
|
||||
token: token,
|
||||
defaultChannel: config["defaultChannel"],
|
||||
token: token,
|
||||
defaultChannel: config["defaultChannel"],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,12 +150,12 @@ func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *d
|
||||
|
||||
func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
|
||||
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{
|
||||
a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("discord:%s", m.ChannelID), openai.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: m.Content,
|
||||
})
|
||||
|
||||
conv := d.conversationTracker.GetConversation(m.ChannelID)
|
||||
conv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("discord:%s", m.ChannelID))
|
||||
|
||||
jobResult := a.Ask(
|
||||
types.WithConversationHistory(conv),
|
||||
@@ -173,7 +166,7 @@ func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *
|
||||
return
|
||||
}
|
||||
|
||||
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{
|
||||
a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("discord:%s", m.ChannelID), openai.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: jobResult.Response,
|
||||
})
|
||||
|
||||
@@ -15,28 +15,22 @@ import (
|
||||
)
|
||||
|
||||
type IRC struct {
|
||||
server string
|
||||
port string
|
||||
nickname string
|
||||
channel string
|
||||
conn *irc.Connection
|
||||
alwaysReply bool
|
||||
conversationTracker *ConversationTracker[string]
|
||||
server string
|
||||
port string
|
||||
nickname string
|
||||
channel string
|
||||
conn *irc.Connection
|
||||
alwaysReply bool
|
||||
}
|
||||
|
||||
func NewIRC(config map[string]string) *IRC {
|
||||
|
||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
||||
if err != nil {
|
||||
duration = 5 * time.Minute
|
||||
}
|
||||
return &IRC{
|
||||
server: config["server"],
|
||||
port: config["port"],
|
||||
nickname: config["nickname"],
|
||||
channel: config["channel"],
|
||||
alwaysReply: config["alwaysReply"] == "true",
|
||||
conversationTracker: NewConversationTracker[string](duration),
|
||||
server: config["server"],
|
||||
port: config["port"],
|
||||
nickname: config["nickname"],
|
||||
channel: config["channel"],
|
||||
alwaysReply: config["alwaysReply"] == "true",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +109,7 @@ func (i *IRC) Start(a *agent.Agent) {
|
||||
cleanedMessage := cleanUpMessage(message, i.nickname)
|
||||
|
||||
go func() {
|
||||
conv := i.conversationTracker.GetConversation(channel)
|
||||
conv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("irc:%s", channel))
|
||||
|
||||
conv = append(conv,
|
||||
openai.ChatCompletionMessage{
|
||||
@@ -125,7 +119,7 @@ func (i *IRC) Start(a *agent.Agent) {
|
||||
)
|
||||
|
||||
// Update the conversation history
|
||||
i.conversationTracker.AddMessage(channel, openai.ChatCompletionMessage{
|
||||
a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("irc:%s", channel), openai.ChatCompletionMessage{
|
||||
Content: cleanedMessage,
|
||||
Role: "user",
|
||||
})
|
||||
@@ -140,7 +134,7 @@ func (i *IRC) Start(a *agent.Agent) {
|
||||
}
|
||||
|
||||
// Update the conversation history
|
||||
i.conversationTracker.AddMessage(channel, openai.ChatCompletionMessage{
|
||||
a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("irc:%s", channel), openai.ChatCompletionMessage{
|
||||
Content: res.Response,
|
||||
Role: "assistant",
|
||||
})
|
||||
@@ -209,7 +203,7 @@ func (i *IRC) Start(a *agent.Agent) {
|
||||
// Start the IRC client in a goroutine
|
||||
go i.conn.Loop()
|
||||
go func() {
|
||||
select {
|
||||
select {
|
||||
case <-a.Context().Done():
|
||||
i.conn.Quit()
|
||||
return
|
||||
@@ -249,11 +243,5 @@ func IRCConfigMeta() []config.Field {
|
||||
Label: "Always Reply",
|
||||
Type: config.FieldTypeCheckbox,
|
||||
},
|
||||
{
|
||||
Name: "lastMessageDuration",
|
||||
Label: "Last Message Duration",
|
||||
Type: config.FieldTypeText,
|
||||
DefaultValue: "5m",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,27 +31,20 @@ type Matrix struct {
|
||||
// Track active jobs for cancellation
|
||||
activeJobs map[string][]*types.Job // map[roomID]bool to track if a room has active processing
|
||||
activeJobsMutex sync.RWMutex
|
||||
|
||||
conversationTracker *ConversationTracker[string]
|
||||
}
|
||||
|
||||
const matrixThinkingMessage = "🤔 thinking..."
|
||||
|
||||
func NewMatrix(config map[string]string) *Matrix {
|
||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
||||
if err != nil {
|
||||
duration = 5 * time.Minute
|
||||
}
|
||||
|
||||
return &Matrix{
|
||||
homeserverURL: config["homeserverURL"],
|
||||
userID: config["userID"],
|
||||
accessToken: config["accessToken"],
|
||||
roomID: config["roomID"],
|
||||
roomMode: config["roomMode"] == "true",
|
||||
conversationTracker: NewConversationTracker[string](duration),
|
||||
placeholders: make(map[string]string),
|
||||
activeJobs: make(map[string][]*types.Job),
|
||||
homeserverURL: config["homeserverURL"],
|
||||
userID: config["userID"],
|
||||
accessToken: config["accessToken"],
|
||||
roomID: config["roomID"],
|
||||
roomMode: config["roomMode"] == "true",
|
||||
placeholders: make(map[string]string),
|
||||
activeJobs: make(map[string][]*types.Job),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,7 +142,7 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) {
|
||||
// Cancel any active job for this room before starting a new one
|
||||
m.cancelActiveJobForRoom(evt.RoomID.String())
|
||||
|
||||
currentConv := m.conversationTracker.GetConversation(evt.RoomID.String())
|
||||
currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("matrix:%s", evt.RoomID.String()))
|
||||
|
||||
message := evt.Content.AsMessage().Body
|
||||
|
||||
@@ -163,8 +156,8 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) {
|
||||
Content: message,
|
||||
})
|
||||
|
||||
m.conversationTracker.AddMessage(
|
||||
evt.RoomID.String(), currentConv[len(currentConv)-1],
|
||||
a.SharedState().ConversationTracker.AddMessage(
|
||||
fmt.Sprintf("matrix:%s", evt.RoomID.String()), currentConv[len(currentConv)-1],
|
||||
)
|
||||
|
||||
agentOptions = append(agentOptions, types.WithConversationHistory(currentConv))
|
||||
@@ -209,8 +202,8 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) {
|
||||
return
|
||||
}
|
||||
|
||||
m.conversationTracker.AddMessage(
|
||||
evt.RoomID.String(), openai.ChatCompletionMessage{
|
||||
a.SharedState().ConversationTracker.AddMessage(
|
||||
fmt.Sprintf("matrix:%s", evt.RoomID.String()), openai.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: res.Response,
|
||||
},
|
||||
@@ -307,11 +300,5 @@ func MatrixConfigMeta() []config.Field {
|
||||
Label: "Room Mode",
|
||||
Type: config.FieldTypeCheckbox,
|
||||
},
|
||||
{
|
||||
Name: "lastMessageDuration",
|
||||
Label: "Last Message Duration",
|
||||
Type: config.FieldTypeText,
|
||||
DefaultValue: "5m",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAGI/pkg/config"
|
||||
"github.com/mudler/LocalAGI/pkg/localoperator"
|
||||
@@ -42,27 +41,19 @@ type Slack struct {
|
||||
// Track active jobs for cancellation
|
||||
activeJobs map[string][]*types.Job // map[channelID]bool to track if a channel has active processing
|
||||
activeJobsMutex sync.RWMutex
|
||||
|
||||
conversationTracker *ConversationTracker[string]
|
||||
}
|
||||
|
||||
const thinkingMessage = ":hourglass: thinking..."
|
||||
|
||||
func NewSlack(config map[string]string) *Slack {
|
||||
|
||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
||||
if err != nil {
|
||||
duration = 5 * time.Minute
|
||||
}
|
||||
|
||||
return &Slack{
|
||||
appToken: config["appToken"],
|
||||
botToken: config["botToken"],
|
||||
channelID: config["channelID"],
|
||||
channelMode: config["channelMode"] == "true",
|
||||
conversationTracker: NewConversationTracker[string](duration),
|
||||
placeholders: make(map[string]string),
|
||||
activeJobs: make(map[string][]*types.Job),
|
||||
appToken: config["appToken"],
|
||||
botToken: config["botToken"],
|
||||
channelID: config["channelID"],
|
||||
channelMode: config["channelMode"] == "true",
|
||||
placeholders: make(map[string]string),
|
||||
activeJobs: make(map[string][]*types.Job),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,16 +131,6 @@ func cleanUpUsernameFromMessage(message string, b *slack.AuthTestResponse) strin
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func extractUserIDsFromMessage(message string) []string {
|
||||
var userIDs []string
|
||||
for _, part := range strings.Split(message, " ") {
|
||||
if strings.HasPrefix(part, "<@") && strings.HasSuffix(part, ">") {
|
||||
userIDs = append(userIDs, strings.TrimPrefix(strings.TrimSuffix(part, ">"), "<@"))
|
||||
}
|
||||
}
|
||||
return userIDs
|
||||
}
|
||||
|
||||
func replaceUserIDsWithNamesInMessage(api *slack.Client, message string) string {
|
||||
for _, part := range strings.Split(message, " ") {
|
||||
if strings.HasPrefix(part, "<@") && strings.HasSuffix(part, ">") {
|
||||
@@ -279,7 +260,7 @@ func (t *Slack) handleChannelMessage(
|
||||
// Cancel any active job for this channel before starting a new one
|
||||
t.cancelActiveJobForChannel(ev.Channel)
|
||||
|
||||
currentConv := t.conversationTracker.GetConversation(t.channelID)
|
||||
currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("slack:%s", t.channelID))
|
||||
|
||||
message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b))
|
||||
|
||||
@@ -323,8 +304,8 @@ func (t *Slack) handleChannelMessage(
|
||||
})
|
||||
}
|
||||
|
||||
t.conversationTracker.AddMessage(
|
||||
t.channelID, currentConv[len(currentConv)-1],
|
||||
a.SharedState().ConversationTracker.AddMessage(
|
||||
fmt.Sprintf("slack:%s", t.channelID), currentConv[len(currentConv)-1],
|
||||
)
|
||||
|
||||
agentOptions = append(agentOptions, types.WithConversationHistory(currentConv))
|
||||
@@ -370,14 +351,14 @@ func (t *Slack) handleChannelMessage(
|
||||
return
|
||||
}
|
||||
|
||||
t.conversationTracker.AddMessage(
|
||||
t.channelID, openai.ChatCompletionMessage{
|
||||
a.SharedState().ConversationTracker.AddMessage(
|
||||
fmt.Sprintf("slack:%s", t.channelID), openai.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: res.Response,
|
||||
},
|
||||
)
|
||||
|
||||
xlog.Debug("After adding message to conversation tracker", "conversation", t.conversationTracker.GetConversation(t.channelID))
|
||||
xlog.Debug("After adding message to conversation tracker", "conversation", a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("slack:%s", t.channelID)))
|
||||
|
||||
//res.Response = githubmarkdownconvertergo.Slack(res.Response)
|
||||
|
||||
@@ -752,6 +733,13 @@ func (t *Slack) Start(a *agent.Agent) {
|
||||
if err != nil {
|
||||
xlog.Error(fmt.Sprintf("Error posting message: %v", err))
|
||||
}
|
||||
a.SharedState().ConversationTracker.AddMessage(
|
||||
fmt.Sprintf("slack:%s", t.channelID),
|
||||
openai.ChatCompletionMessage{
|
||||
Content: ccm.Content,
|
||||
Role: "assistant",
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -835,11 +823,5 @@ func SlackConfigMeta() []config.Field {
|
||||
Label: "Always Reply",
|
||||
Type: config.FieldTypeCheckbox,
|
||||
},
|
||||
{
|
||||
Name: "lastMessageDuration",
|
||||
Label: "Last Message Duration",
|
||||
Type: config.FieldTypeText,
|
||||
DefaultValue: "5m",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-telegram/bot"
|
||||
"github.com/go-telegram/bot/models"
|
||||
@@ -35,14 +34,8 @@ type Telegram struct {
|
||||
bot *bot.Bot
|
||||
agent *agent.Agent
|
||||
|
||||
currentconversation map[int64][]openai.ChatCompletionMessage
|
||||
lastMessageTime map[int64]time.Time
|
||||
lastMessageDuration time.Duration
|
||||
|
||||
admins []string
|
||||
|
||||
conversationTracker *ConversationTracker[int64]
|
||||
|
||||
// To track placeholder messages
|
||||
placeholders map[string]int // map[jobUUID]messageID
|
||||
placeholderMutex sync.RWMutex
|
||||
@@ -50,6 +43,8 @@ type Telegram struct {
|
||||
// Track active jobs for cancellation
|
||||
activeJobs map[int64][]*types.Job // map[chatID]bool to track if a chat has active processing
|
||||
activeJobsMutex sync.RWMutex
|
||||
|
||||
channelID string
|
||||
}
|
||||
|
||||
// Send any text message to the bot after the bot has been started
|
||||
@@ -219,6 +214,8 @@ func formatResponseWithURLs(response string, urls []string) string {
|
||||
func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent, update *models.Update) {
|
||||
username := update.Message.From.Username
|
||||
|
||||
xlog.Debug("Received message from user", "username", username, "chatID", update.Message.Chat.ID, "message", update.Message.Text)
|
||||
|
||||
internalError := func(err error, msg *models.Message) {
|
||||
xlog.Error("Error updating final message", "error", err)
|
||||
b.EditMessageText(ctx, &bot.EditMessageTextParams{
|
||||
@@ -242,14 +239,14 @@ func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent,
|
||||
// Cancel any active job for this chat before starting a new one
|
||||
t.cancelActiveJobForChat(update.Message.Chat.ID)
|
||||
|
||||
currentConv := t.conversationTracker.GetConversation(update.Message.From.ID)
|
||||
currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("telegram:%d", update.Message.From.ID))
|
||||
currentConv = append(currentConv, openai.ChatCompletionMessage{
|
||||
Content: update.Message.Text,
|
||||
Role: "user",
|
||||
})
|
||||
|
||||
t.conversationTracker.AddMessage(
|
||||
update.Message.From.ID,
|
||||
a.SharedState().ConversationTracker.AddMessage(
|
||||
fmt.Sprintf("telegram:%d", update.Message.From.ID),
|
||||
openai.ChatCompletionMessage{
|
||||
Content: update.Message.Text,
|
||||
Role: "user",
|
||||
@@ -328,8 +325,8 @@ func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent,
|
||||
return
|
||||
}
|
||||
|
||||
t.conversationTracker.AddMessage(
|
||||
update.Message.From.ID,
|
||||
a.SharedState().ConversationTracker.AddMessage(
|
||||
fmt.Sprintf("telegram:%d", update.Message.From.ID),
|
||||
openai.ChatCompletionMessage{
|
||||
Content: res.Response,
|
||||
Role: "assistant",
|
||||
@@ -408,11 +405,34 @@ func (t *Telegram) Start(a *agent.Agent) {
|
||||
t.agent = a
|
||||
|
||||
// go func() {
|
||||
// for m := range a.ConversationChannel() {
|
||||
// forc m := range a.ConversationChannel() {
|
||||
// t.handleNewMessage(ctx, b, m)
|
||||
// }
|
||||
// }()
|
||||
|
||||
if t.channelID != "" {
|
||||
// handle new conversations
|
||||
a.AddSubscriber(func(ccm openai.ChatCompletionMessage) {
|
||||
xlog.Debug("Subscriber(telegram)", "message", ccm.Content)
|
||||
_, err := b.SendMessage(ctx, &bot.SendMessageParams{
|
||||
ChatID: t.channelID,
|
||||
Text: ccm.Content,
|
||||
})
|
||||
if err != nil {
|
||||
xlog.Error("Error sending message", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
t.agent.SharedState().ConversationTracker.AddMessage(
|
||||
fmt.Sprintf("telegram:%s", t.channelID),
|
||||
openai.ChatCompletionMessage{
|
||||
Content: ccm.Content,
|
||||
Role: "assistant",
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
b.Start(ctx)
|
||||
}
|
||||
|
||||
@@ -422,11 +442,6 @@ func NewTelegramConnector(config map[string]string) (*Telegram, error) {
|
||||
return nil, errors.New("token is required")
|
||||
}
|
||||
|
||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
||||
if err != nil {
|
||||
duration = 5 * time.Minute
|
||||
}
|
||||
|
||||
admins := []string{}
|
||||
|
||||
if _, ok := config["admins"]; ok {
|
||||
@@ -434,14 +449,11 @@ func NewTelegramConnector(config map[string]string) (*Telegram, error) {
|
||||
}
|
||||
|
||||
return &Telegram{
|
||||
Token: token,
|
||||
lastMessageDuration: duration,
|
||||
admins: admins,
|
||||
currentconversation: map[int64][]openai.ChatCompletionMessage{},
|
||||
lastMessageTime: map[int64]time.Time{},
|
||||
conversationTracker: NewConversationTracker[int64](duration),
|
||||
placeholders: make(map[string]int),
|
||||
activeJobs: make(map[int64][]*types.Job),
|
||||
Token: token,
|
||||
admins: admins,
|
||||
placeholders: make(map[string]int),
|
||||
activeJobs: make(map[int64][]*types.Job),
|
||||
channelID: config["channel_id"],
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -461,10 +473,10 @@ func TelegramConfigMeta() []config.Field {
|
||||
HelpText: "Comma-separated list of Telegram usernames that are allowed to interact with the bot",
|
||||
},
|
||||
{
|
||||
Name: "lastMessageDuration",
|
||||
Label: "Last Message Duration",
|
||||
Type: config.FieldTypeText,
|
||||
DefaultValue: "5m",
|
||||
Name: "channel_id",
|
||||
Label: "Channel ID",
|
||||
Type: config.FieldTypeText,
|
||||
HelpText: "Telegram channel ID to send messages to if the agent needs to initiate a conversation",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user