feat: track conversations inside connectors (#92)
* switch to observer pattern Signed-off-by: mudler <mudler@localai.io> * keep conversation history in telegram and slack * generalize with conversation tracker --------- Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
53c1554d55
commit
d0cfc4c317
@@ -40,6 +40,8 @@ type Agent struct {
|
||||
newConversations chan openai.ChatCompletionMessage
|
||||
|
||||
mcpActions types.Actions
|
||||
|
||||
newMessagesSubscribers []func(openai.ChatCompletionMessage)
|
||||
}
|
||||
|
||||
type RAGDB interface {
|
||||
@@ -64,12 +66,13 @@ func New(opts ...Option) (*Agent, error) {
|
||||
|
||||
ctx, cancel := context.WithCancel(c)
|
||||
a := &Agent{
|
||||
jobQueue: make(chan *types.Job),
|
||||
options: options,
|
||||
client: client,
|
||||
Character: options.character,
|
||||
currentState: &action.AgentInternalState{},
|
||||
context: types.NewActionContext(ctx, cancel),
|
||||
jobQueue: make(chan *types.Job),
|
||||
options: options,
|
||||
client: client,
|
||||
Character: options.character,
|
||||
currentState: &action.AgentInternalState{},
|
||||
context: types.NewActionContext(ctx, cancel),
|
||||
newMessagesSubscribers: options.newConversationsSubscribers,
|
||||
}
|
||||
|
||||
if a.options.statefile != "" {
|
||||
@@ -102,9 +105,27 @@ func New(opts ...Option) (*Agent, error) {
|
||||
"model", a.options.LLMAPI.Model,
|
||||
)
|
||||
|
||||
a.startNewConversationsConsumer()
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *Agent) startNewConversationsConsumer() {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-a.context.Done():
|
||||
return
|
||||
|
||||
case msg := <-a.newConversations:
|
||||
for _, s := range a.newMessagesSubscribers {
|
||||
s(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// StopAction stops the current action
|
||||
// if any. Can be called before adding a new job.
|
||||
func (a *Agent) StopAction() {
|
||||
@@ -124,10 +145,6 @@ func (a *Agent) ActionContext() context.Context {
|
||||
return a.actionContext.Context
|
||||
}
|
||||
|
||||
func (a *Agent) ConversationChannel() chan openai.ChatCompletionMessage {
|
||||
return a.newConversations
|
||||
}
|
||||
|
||||
// Ask is a pre-emptive, blocking call that returns the response as soon as it's ready.
|
||||
// It discards any other computation.
|
||||
func (a *Agent) Ask(opts ...types.JobOption) *types.JobResult {
|
||||
|
||||
@@ -19,8 +19,7 @@ func (a *Agent) knowledgeBaseLookup(conv Messages) {
|
||||
|
||||
// Walk conversation from bottom to top, and find the first message of the user
|
||||
// to use it as a query to the KB
|
||||
var userMessage string
|
||||
userMessage = conv.GetLatestUserMessage().Content
|
||||
userMessage := conv.GetLatestUserMessage().Content
|
||||
|
||||
xlog.Info("[Knowledge Base Lookup] Last user message", "agent", a.Character.Name, "message", userMessage, "lastMessage", conv.GetLatestUserMessage())
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAgent/core/types"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type Option func(*options) error
|
||||
@@ -49,6 +50,8 @@ type options struct {
|
||||
conversationsPath string
|
||||
|
||||
mcpServers []MCPServer
|
||||
|
||||
newConversationsSubscribers []func(openai.ChatCompletionMessage)
|
||||
}
|
||||
|
||||
func (o *options) SeparatedMultimodalModel() bool {
|
||||
@@ -125,6 +128,13 @@ func EnableKnowledgeBaseWithResults(results int) Option {
|
||||
}
|
||||
}
|
||||
|
||||
func WithNewConversationSubscriber(sub func(openai.ChatCompletionMessage)) Option {
|
||||
return func(o *options) error {
|
||||
o.newConversationsSubscribers = append(o.newConversationsSubscribers, sub)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var EnableInitiateConversations = func(o *options) error {
|
||||
o.initiateConversations = true
|
||||
return nil
|
||||
|
||||
75
services/connectors/conversationstracker.go
Normal file
75
services/connectors/conversationstracker.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package connectors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAgent/pkg/xlog"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type TrackerKey interface{ ~int | ~int64 | ~string }
|
||||
|
||||
type ConversationTracker[K TrackerKey] struct {
|
||||
convMutex sync.Mutex
|
||||
currentconversation map[K][]openai.ChatCompletionMessage
|
||||
lastMessageTime map[K]time.Time
|
||||
lastMessageDuration time.Duration
|
||||
}
|
||||
|
||||
func NewConversationTracker[K TrackerKey](lastMessageDuration time.Duration) *ConversationTracker[K] {
|
||||
return &ConversationTracker[K]{
|
||||
lastMessageDuration: lastMessageDuration,
|
||||
currentconversation: map[K][]openai.ChatCompletionMessage{},
|
||||
lastMessageTime: map[K]time.Time{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConversationTracker[K]) GetConversation(key K) []openai.ChatCompletionMessage {
|
||||
// Lock the conversation mutex to update the conversation history
|
||||
c.convMutex.Lock()
|
||||
defer c.convMutex.Unlock()
|
||||
|
||||
// Clear up the conversation if the last message was sent more than lastMessageDuration ago
|
||||
currentConv := []openai.ChatCompletionMessage{}
|
||||
lastMessageTime := c.lastMessageTime[key]
|
||||
if lastMessageTime.IsZero() {
|
||||
lastMessageTime = time.Now()
|
||||
}
|
||||
if lastMessageTime.Add(c.lastMessageDuration).Before(time.Now()) {
|
||||
currentConv = []openai.ChatCompletionMessage{}
|
||||
c.lastMessageTime[key] = time.Now()
|
||||
xlog.Debug("Conversation history does not exist for", "key", fmt.Sprintf("%v", key))
|
||||
} else {
|
||||
xlog.Debug("Conversation history exists for", "key", fmt.Sprintf("%v", key))
|
||||
currentConv = append(currentConv, c.currentconversation[key]...)
|
||||
}
|
||||
|
||||
// cleanup other conversations if older
|
||||
for k := range c.currentconversation {
|
||||
lastMessage, exists := c.lastMessageTime[k]
|
||||
if !exists {
|
||||
delete(c.currentconversation, k)
|
||||
delete(c.lastMessageTime, k)
|
||||
continue
|
||||
}
|
||||
if lastMessage.Add(c.lastMessageDuration).Before(time.Now()) {
|
||||
xlog.Debug("Cleaning up conversation for", k)
|
||||
delete(c.currentconversation, k)
|
||||
delete(c.lastMessageTime, k)
|
||||
}
|
||||
}
|
||||
|
||||
return currentConv
|
||||
|
||||
}
|
||||
|
||||
func (c *ConversationTracker[K]) AddMessage(key K, message openai.ChatCompletionMessage) {
|
||||
// Lock the conversation mutex to update the conversation history
|
||||
c.convMutex.Lock()
|
||||
defer c.convMutex.Unlock()
|
||||
|
||||
c.currentconversation[key] = append(c.currentconversation[key], message)
|
||||
c.lastMessageTime[key] = time.Now()
|
||||
}
|
||||
@@ -1,17 +1,19 @@
|
||||
package connectors
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/mudler/LocalAgent/core/agent"
|
||||
"github.com/mudler/LocalAgent/core/types"
|
||||
"github.com/mudler/LocalAgent/pkg/xlog"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type Discord struct {
|
||||
token string
|
||||
defaultChannel string
|
||||
token string
|
||||
defaultChannel string
|
||||
conversationTracker *ConversationTracker[string]
|
||||
}
|
||||
|
||||
// NewDiscord creates a new Discord connector
|
||||
@@ -19,9 +21,16 @@ type Discord struct {
|
||||
// - token: Discord token
|
||||
// - defaultChannel: Discord channel to always answer even if not mentioned
|
||||
func NewDiscord(config map[string]string) *Discord {
|
||||
|
||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
||||
if err != nil {
|
||||
duration = 5 * time.Minute
|
||||
}
|
||||
|
||||
return &Discord{
|
||||
token: config["token"],
|
||||
defaultChannel: config["defaultChannel"],
|
||||
conversationTracker: NewConversationTracker[string](duration),
|
||||
token: config["token"],
|
||||
defaultChannel: config["defaultChannel"],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,6 +78,71 @@ func (d *Discord) Start(a *agent.Agent) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
var messages []*discordgo.Message
|
||||
var err error
|
||||
messages, err = s.ChannelMessages(m.ChannelID, 100, "", m.MessageReference.MessageID, "")
|
||||
if err != nil {
|
||||
xlog.Info("error getting messages,", err)
|
||||
return
|
||||
}
|
||||
|
||||
conv := []openai.ChatCompletionMessage{}
|
||||
|
||||
for _, message := range messages {
|
||||
if message.Author.ID == s.State.User.ID {
|
||||
conv = append(conv, openai.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: message.Content,
|
||||
})
|
||||
} else {
|
||||
conv = append(conv, openai.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
jobResult := a.Ask(
|
||||
types.WithConversationHistory(conv),
|
||||
)
|
||||
|
||||
if jobResult.Error != nil {
|
||||
xlog.Info("error asking agent,", jobResult.Error)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = s.ChannelMessageSend(m.ChannelID, jobResult.Response)
|
||||
if err != nil {
|
||||
xlog.Info("error sending message,", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
|
||||
conv := d.conversationTracker.GetConversation(m.ChannelID)
|
||||
|
||||
jobResult := a.Ask(
|
||||
types.WithConversationHistory(conv),
|
||||
)
|
||||
|
||||
if jobResult.Error != nil {
|
||||
xlog.Info("error asking agent,", jobResult.Error)
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.ChannelMessageSend(m.ChannelID, jobResult.Response)
|
||||
if err != nil {
|
||||
xlog.Info("error sending message,", err)
|
||||
}
|
||||
|
||||
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: m.Content,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// This function will be called (due to AddHandler above) every time a new
|
||||
// message is created on any channel that the authenticated bot has access to.
|
||||
func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
@@ -78,40 +152,30 @@ func (d *Discord) messageCreate(a *agent.Agent) func(s *discordgo.Session, m *di
|
||||
if m.Author.ID == s.State.User.ID {
|
||||
return
|
||||
}
|
||||
interact := func() {
|
||||
//m := m.ContentWithMentionsReplaced()
|
||||
content := m.Content
|
||||
|
||||
content = strings.ReplaceAll(content, "<@"+s.State.User.ID+"> ", "")
|
||||
xlog.Info("Received message", "content", content)
|
||||
job := a.Ask(
|
||||
types.WithText(
|
||||
content,
|
||||
),
|
||||
)
|
||||
if job.Error != nil {
|
||||
xlog.Info("error asking agent,", job.Error)
|
||||
// Interact if we are mentioned
|
||||
mentioned := false
|
||||
for _, mention := range m.Mentions {
|
||||
if mention.ID == s.State.User.ID {
|
||||
mentioned = true
|
||||
return
|
||||
}
|
||||
|
||||
xlog.Info("Response", "response", job.Response)
|
||||
_, err := s.ChannelMessageSend(m.ChannelID, job.Response)
|
||||
if err != nil {
|
||||
xlog.Info("error sending message,", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Interact if we are mentioned
|
||||
for _, mention := range m.Mentions {
|
||||
if mention.ID == s.State.User.ID {
|
||||
go interact()
|
||||
return
|
||||
}
|
||||
if !mentioned && d.defaultChannel == "" {
|
||||
xlog.Debug("Not mentioned")
|
||||
return
|
||||
}
|
||||
|
||||
// check if the message is in a thread and get all messages in the thread
|
||||
if m.MessageReference != nil {
|
||||
d.handleThreadMessage(a, s, m)
|
||||
return
|
||||
}
|
||||
|
||||
// Or we are in the default channel (if one is set!)
|
||||
if d.defaultChannel != "" && m.ChannelID == d.defaultChannel {
|
||||
go interact()
|
||||
d.handleChannelMessage(a, s, m)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,25 +9,33 @@ import (
|
||||
"github.com/mudler/LocalAgent/core/types"
|
||||
"github.com/mudler/LocalAgent/pkg/xlog"
|
||||
"github.com/mudler/LocalAgent/services/actions"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
irc "github.com/thoj/go-ircevent"
|
||||
)
|
||||
|
||||
type IRC struct {
|
||||
server string
|
||||
port string
|
||||
nickname string
|
||||
channel string
|
||||
conn *irc.Connection
|
||||
alwaysReply bool
|
||||
server string
|
||||
port string
|
||||
nickname string
|
||||
channel string
|
||||
conn *irc.Connection
|
||||
alwaysReply bool
|
||||
conversationTracker *ConversationTracker[string]
|
||||
}
|
||||
|
||||
func NewIRC(config map[string]string) *IRC {
|
||||
|
||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
||||
if err != nil {
|
||||
duration = 5 * time.Minute
|
||||
}
|
||||
return &IRC{
|
||||
server: config["server"],
|
||||
port: config["port"],
|
||||
nickname: config["nickname"],
|
||||
channel: config["channel"],
|
||||
alwaysReply: config["alwaysReply"] == "true",
|
||||
server: config["server"],
|
||||
port: config["port"],
|
||||
nickname: config["nickname"],
|
||||
channel: config["channel"],
|
||||
alwaysReply: config["alwaysReply"] == "true",
|
||||
conversationTracker: NewConversationTracker[string](duration),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,13 +110,33 @@ func (i *IRC) Start(a *agent.Agent) {
|
||||
}
|
||||
|
||||
xlog.Info("Recv message", "message", message, "sender", sender, "channel", channel)
|
||||
cleanedMessage := "My name is " + sender + ". " + cleanUpMessage(message, i.nickname)
|
||||
cleanedMessage := cleanUpMessage(message, i.nickname)
|
||||
|
||||
go func() {
|
||||
res := a.Ask(
|
||||
types.WithText(cleanedMessage),
|
||||
conv := i.conversationTracker.GetConversation(channel)
|
||||
|
||||
conv = append(conv,
|
||||
openai.ChatCompletionMessage{
|
||||
Content: cleanedMessage,
|
||||
Role: "user",
|
||||
},
|
||||
)
|
||||
|
||||
res := a.Ask(
|
||||
types.WithConversationHistory(conv),
|
||||
)
|
||||
|
||||
if res.Response == "" {
|
||||
xlog.Info("No response from agent")
|
||||
return
|
||||
}
|
||||
|
||||
// Update the conversation history
|
||||
i.conversationTracker.AddMessage(channel, openai.ChatCompletionMessage{
|
||||
Content: res.Response,
|
||||
Role: "assistant",
|
||||
})
|
||||
|
||||
xlog.Info("Sending message", "message", res.Response, "channel", channel)
|
||||
|
||||
// Split the response into multiple messages if it's too long
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAgent/pkg/xlog"
|
||||
"github.com/mudler/LocalAgent/services/actions"
|
||||
@@ -35,17 +36,28 @@ type Slack struct {
|
||||
placeholders map[string]string // map[jobUUID]messageTS
|
||||
placeholderMutex sync.RWMutex
|
||||
apiClient *slack.Client
|
||||
|
||||
conversationTracker *ConversationTracker[string]
|
||||
processing sync.Mutex
|
||||
processingMessage bool
|
||||
}
|
||||
|
||||
const thinkingMessage = "thinking..."
|
||||
|
||||
func NewSlack(config map[string]string) *Slack {
|
||||
|
||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
||||
if err != nil {
|
||||
duration = 5 * time.Minute
|
||||
}
|
||||
|
||||
return &Slack{
|
||||
appToken: config["appToken"],
|
||||
botToken: config["botToken"],
|
||||
channelID: config["channelID"],
|
||||
alwaysReply: config["alwaysReply"] == "true",
|
||||
placeholders: make(map[string]string),
|
||||
appToken: config["appToken"],
|
||||
botToken: config["botToken"],
|
||||
channelID: config["channelID"],
|
||||
alwaysReply: config["alwaysReply"] == "true",
|
||||
conversationTracker: NewConversationTracker[string](duration),
|
||||
placeholders: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,6 +194,26 @@ func (t *Slack) handleChannelMessage(
|
||||
return
|
||||
}
|
||||
|
||||
currentConv := t.conversationTracker.GetConversation(t.channelID)
|
||||
|
||||
// Lock the conversation mutex to update the conversation history
|
||||
t.processing.Lock()
|
||||
|
||||
// If we are already processing something, stop the current action
|
||||
if t.processingMessage {
|
||||
a.StopAction()
|
||||
} else {
|
||||
t.processingMessage = true
|
||||
}
|
||||
t.processing.Unlock()
|
||||
|
||||
// Defer to reset the processing flag
|
||||
defer func() {
|
||||
t.processing.Lock()
|
||||
t.processingMessage = false
|
||||
t.processing.Unlock()
|
||||
}()
|
||||
|
||||
message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b))
|
||||
|
||||
go func() {
|
||||
@@ -221,22 +253,59 @@ func (t *Slack) handleChannelMessage(
|
||||
|
||||
// If the last message has an image, we send it as a multi content message
|
||||
if len(imageBytes.Bytes()) > 0 {
|
||||
|
||||
// // Encode the image to base64
|
||||
imgBase64, err := encodeImageFromURL(*imageBytes)
|
||||
if err != nil {
|
||||
xlog.Error(fmt.Sprintf("Error encoding image to base64: %v", err))
|
||||
} else {
|
||||
agentOptions = append(agentOptions, types.WithTextImage(message, fmt.Sprintf("data:%s;base64,%s", mimeType, imgBase64)))
|
||||
currentConv = append(currentConv,
|
||||
openai.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
MultiContent: []openai.ChatMessagePart{
|
||||
{
|
||||
Text: message,
|
||||
Type: openai.ChatMessagePartTypeText,
|
||||
},
|
||||
{
|
||||
Type: openai.ChatMessagePartTypeImageURL,
|
||||
ImageURL: &openai.ChatMessageImageURL{
|
||||
URL: fmt.Sprintf("data:%s;base64,%s", mimeType, imgBase64),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
} else {
|
||||
agentOptions = append(agentOptions, types.WithText(message))
|
||||
currentConv = append(currentConv, openai.ChatCompletionMessage{
|
||||
Role: "user",
|
||||
Content: message,
|
||||
})
|
||||
}
|
||||
|
||||
agentOptions = append(agentOptions, types.WithConversationHistory(currentConv))
|
||||
|
||||
res := a.Ask(
|
||||
agentOptions...,
|
||||
)
|
||||
|
||||
if res.Response == "" {
|
||||
xlog.Debug(fmt.Sprintf("Empty response from agent"))
|
||||
return
|
||||
}
|
||||
|
||||
if res.Error != nil {
|
||||
xlog.Error(fmt.Sprintf("Error from agent: %v", res.Error))
|
||||
return
|
||||
}
|
||||
|
||||
t.conversationTracker.AddMessage(
|
||||
t.channelID, openai.ChatCompletionMessage{
|
||||
Role: "assistant",
|
||||
Content: res.Response,
|
||||
},
|
||||
)
|
||||
|
||||
//res.Response = githubmarkdownconvertergo.Slack(res.Response)
|
||||
|
||||
_, _, err = api.PostMessage(ev.Channel,
|
||||
@@ -250,6 +319,7 @@ func (t *Slack) handleChannelMessage(
|
||||
if err != nil {
|
||||
xlog.Error(fmt.Sprintf("Error posting message: %v", err))
|
||||
}
|
||||
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -486,6 +556,15 @@ func (t *Slack) handleMention(
|
||||
types.WithMetadata(metadata),
|
||||
)
|
||||
|
||||
if res.Response == "" {
|
||||
xlog.Debug(fmt.Sprintf("Empty response from agent"))
|
||||
_, _, err := api.DeleteMessage(ev.Channel, msgTs)
|
||||
if err != nil {
|
||||
xlog.Error(fmt.Sprintf("Error deleting message: %v", err))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// get user id
|
||||
user, err := api.GetUserInfo(ev.User)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,18 +5,30 @@ import (
|
||||
"errors"
|
||||
"os"
|
||||
"os/signal"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-telegram/bot"
|
||||
"github.com/go-telegram/bot/models"
|
||||
"github.com/mudler/LocalAgent/core/agent"
|
||||
"github.com/mudler/LocalAgent/core/types"
|
||||
"github.com/mudler/LocalAgent/pkg/xlog"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type Telegram struct {
|
||||
Token string
|
||||
lastChatID int64
|
||||
bot *bot.Bot
|
||||
agent *agent.Agent
|
||||
Token string
|
||||
bot *bot.Bot
|
||||
agent *agent.Agent
|
||||
|
||||
currentconversation map[int64][]openai.ChatCompletionMessage
|
||||
lastMessageTime map[int64]time.Time
|
||||
lastMessageDuration time.Duration
|
||||
|
||||
admins []string
|
||||
|
||||
conversationTracker *ConversationTracker[int64]
|
||||
}
|
||||
|
||||
// Send any text message to the bot after the bot has been started
|
||||
@@ -38,24 +50,60 @@ func (t *Telegram) AgentReasoningCallback() func(state types.ActionCurrentState)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent, update *models.Update) {
|
||||
username := update.Message.From.Username
|
||||
|
||||
if len(t.admins) > 0 && !slices.Contains(t.admins, username) {
|
||||
xlog.Info("Unauthorized user", "username", username)
|
||||
return
|
||||
}
|
||||
|
||||
currentConv := t.conversationTracker.GetConversation(update.Message.From.ID)
|
||||
currentConv = append(currentConv, openai.ChatCompletionMessage{
|
||||
Content: update.Message.Text,
|
||||
Role: "user",
|
||||
})
|
||||
|
||||
res := a.Ask(
|
||||
types.WithConversationHistory(currentConv),
|
||||
)
|
||||
|
||||
if res.Response == "" {
|
||||
return
|
||||
}
|
||||
|
||||
t.conversationTracker.AddMessage(
|
||||
update.Message.From.ID,
|
||||
openai.ChatCompletionMessage{
|
||||
Content: res.Response,
|
||||
Role: "assistant",
|
||||
},
|
||||
)
|
||||
|
||||
b.SendMessage(ctx, &bot.SendMessageParams{
|
||||
ParseMode: models.ParseModeMarkdown,
|
||||
ChatID: update.Message.Chat.ID,
|
||||
Text: res.Response,
|
||||
})
|
||||
}
|
||||
|
||||
// func (t *Telegram) handleNewMessage(ctx context.Context, b *bot.Bot, m openai.ChatCompletionMessage) {
|
||||
// if t.lastChatID == 0 {
|
||||
// return
|
||||
// }
|
||||
// b.SendMessage(ctx, &bot.SendMessageParams{
|
||||
// ChatID: t.lastChatID,
|
||||
// Text: m.Content,
|
||||
// })
|
||||
// }
|
||||
|
||||
func (t *Telegram) Start(a *agent.Agent) {
|
||||
ctx, cancel := signal.NotifyContext(a.Context(), os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
opts := []bot.Option{
|
||||
bot.WithDefaultHandler(func(ctx context.Context, b *bot.Bot, update *models.Update) {
|
||||
go func() {
|
||||
res := a.Ask(
|
||||
types.WithText(
|
||||
update.Message.Text,
|
||||
),
|
||||
)
|
||||
b.SendMessage(ctx, &bot.SendMessageParams{
|
||||
ChatID: update.Message.Chat.ID,
|
||||
Text: res.Response,
|
||||
})
|
||||
t.lastChatID = update.Message.Chat.ID
|
||||
}()
|
||||
go t.handleUpdate(ctx, b, a, update)
|
||||
}),
|
||||
}
|
||||
|
||||
@@ -67,17 +115,11 @@ func (t *Telegram) Start(a *agent.Agent) {
|
||||
t.bot = b
|
||||
t.agent = a
|
||||
|
||||
go func() {
|
||||
for m := range a.ConversationChannel() {
|
||||
if t.lastChatID == 0 {
|
||||
continue
|
||||
}
|
||||
b.SendMessage(ctx, &bot.SendMessageParams{
|
||||
ChatID: t.lastChatID,
|
||||
Text: m.Content,
|
||||
})
|
||||
}
|
||||
}()
|
||||
// go func() {
|
||||
// for m := range a.ConversationChannel() {
|
||||
// t.handleNewMessage(ctx, b, m)
|
||||
// }
|
||||
// }()
|
||||
|
||||
b.Start(ctx)
|
||||
}
|
||||
@@ -88,7 +130,23 @@ func NewTelegramConnector(config map[string]string) (*Telegram, error) {
|
||||
return nil, errors.New("token is required")
|
||||
}
|
||||
|
||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
||||
if err != nil {
|
||||
duration = 5 * time.Minute
|
||||
}
|
||||
|
||||
admins := []string{}
|
||||
|
||||
if _, ok := config["admins"]; ok {
|
||||
admins = append(admins, strings.Split(config["admins"], ",")...)
|
||||
}
|
||||
|
||||
return &Telegram{
|
||||
Token: token,
|
||||
Token: token,
|
||||
lastMessageDuration: duration,
|
||||
admins: admins,
|
||||
currentconversation: map[int64][]openai.ChatCompletionMessage{},
|
||||
lastMessageTime: map[int64]time.Time{},
|
||||
conversationTracker: NewConversationTracker[int64](duration),
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user