feat(agent): shared state, allow to track conversations globally (#148)
* feat(agent): shared state, allow to track conversations globally Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Cleanup Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * track conversations initiated by the bot Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
2b07dd79ec
commit
c23e655f44
@@ -81,7 +81,7 @@ func (a *CustomAction) Plannable() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *CustomAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *CustomAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
v, err := a.i.Eval(fmt.Sprintf("%s.Run", a.config["name"]))
|
v, err := a.i.Eval(fmt.Sprintf("%s.Run", a.config["name"]))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ActionResult{}, err
|
return types.ActionResult{}, err
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ return []string{"foo"}
|
|||||||
Description: "A test action",
|
Description: "A test action",
|
||||||
}))
|
}))
|
||||||
|
|
||||||
runResult, err := customAction.Run(context.Background(), types.ActionParams{
|
runResult, err := customAction.Run(context.Background(), nil, types.ActionParams{
|
||||||
"Foo": "bar",
|
"Foo": "bar",
|
||||||
})
|
})
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ type GoalResponse struct {
|
|||||||
Achieved bool `json:"achieved"`
|
Achieved bool `json:"achieved"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *GoalAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
func (a *GoalAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
return types.ActionResult{}, nil
|
return types.ActionResult{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ type IntentResponse struct {
|
|||||||
Reasoning string `json:"reasoning"`
|
Reasoning string `json:"reasoning"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *IntentAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
func (a *IntentAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
return types.ActionResult{}, nil
|
return types.ActionResult{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ type ConversationActionResponse struct {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ConversationAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
func (a *ConversationAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
return types.ActionResult{}, nil
|
return types.ActionResult{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func NewStop() *StopAction {
|
|||||||
|
|
||||||
type StopAction struct{}
|
type StopAction struct{}
|
||||||
|
|
||||||
func (a *StopAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
func (a *StopAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
return types.ActionResult{}, nil
|
return types.ActionResult{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ type PlanSubtask struct {
|
|||||||
Reasoning string `json:"reasoning"`
|
Reasoning string `json:"reasoning"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *PlanAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
func (a *PlanAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
return types.ActionResult{}, nil
|
return types.ActionResult{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ type ReasoningResponse struct {
|
|||||||
Reasoning string `json:"reasoning"`
|
Reasoning string `json:"reasoning"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ReasoningAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
func (a *ReasoningAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
return types.ActionResult{}, nil
|
return types.ActionResult{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ type ReplyResponse struct {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ReplyAction) Run(context.Context, types.ActionParams) (string, error) {
|
func (a *ReplyAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (string, error) {
|
||||||
return "no-op", nil
|
return "no-op", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ func NewState() *StateAction {
|
|||||||
|
|
||||||
type StateAction struct{}
|
type StateAction struct{}
|
||||||
|
|
||||||
func (a *StateAction) Run(context.Context, types.ActionParams) (types.ActionResult, error) {
|
func (a *StateAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
return types.ActionResult{Result: "internal state has been updated"}, nil
|
return types.ActionResult{Result: "internal state has been updated"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ type Agent struct {
|
|||||||
newMessagesSubscribers []func(openai.ChatCompletionMessage)
|
newMessagesSubscribers []func(openai.ChatCompletionMessage)
|
||||||
|
|
||||||
observer Observer
|
observer Observer
|
||||||
|
|
||||||
|
sharedState *types.AgentSharedState
|
||||||
}
|
}
|
||||||
|
|
||||||
type RAGDB interface {
|
type RAGDB interface {
|
||||||
@@ -78,6 +80,7 @@ func New(opts ...Option) (*Agent, error) {
|
|||||||
context: types.NewActionContext(ctx, cancel),
|
context: types.NewActionContext(ctx, cancel),
|
||||||
newConversations: make(chan openai.ChatCompletionMessage),
|
newConversations: make(chan openai.ChatCompletionMessage),
|
||||||
newMessagesSubscribers: options.newConversationsSubscribers,
|
newMessagesSubscribers: options.newConversationsSubscribers,
|
||||||
|
sharedState: types.NewAgentSharedState(options.lastMessageDuration),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize observer if provided
|
// Initialize observer if provided
|
||||||
@@ -118,6 +121,10 @@ func New(opts ...Option) (*Agent, error) {
|
|||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Agent) SharedState() *types.AgentSharedState {
|
||||||
|
return a.sharedState
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Agent) startNewConversationsConsumer() {
|
func (a *Agent) startNewConversationsConsumer() {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
@@ -294,7 +301,7 @@ func (a *Agent) runAction(job *types.Job, chosenAction types.Action, params type
|
|||||||
|
|
||||||
for _, act := range a.availableActions() {
|
for _, act := range a.availableActions() {
|
||||||
if act.Definition().Name == chosenAction.Definition().Name {
|
if act.Definition().Name == chosenAction.Definition().Name {
|
||||||
res, err := act.Run(job.GetContext(), params)
|
res, err := act.Run(job.GetContext(), a.sharedState, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if obs != nil {
|
if obs != nil {
|
||||||
obs.Completion = &types.Completion{
|
obs.Completion = &types.Completion{
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func (a *TestAction) Plannable() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *TestAction) Run(c context.Context, p types.ActionParams) (types.ActionResult, error) {
|
func (a *TestAction) Run(c context.Context, sharedState *types.AgentSharedState, p types.ActionParams) (types.ActionResult, error) {
|
||||||
for k, r := range a.response {
|
for k, r := range a.response {
|
||||||
if strings.Contains(strings.ToLower(p.String()), strings.ToLower(k)) {
|
if strings.Contains(strings.ToLower(p.String()), strings.ToLower(k)) {
|
||||||
return types.ActionResult{Result: r}, nil
|
return types.ActionResult{Result: r}, nil
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func (a *mcpAction) Plannable() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mcpAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (m *mcpAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
resp, err := m.mcpClient.CallTool(ctx, m.toolName, params)
|
resp, err := m.mcpClient.CallTool(ctx, m.toolName, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("Failed to call tool", "error", err.Error())
|
xlog.Error("Failed to call tool", "error", err.Error())
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ type options struct {
|
|||||||
|
|
||||||
observer Observer
|
observer Observer
|
||||||
parallelJobs int
|
parallelJobs int
|
||||||
|
|
||||||
|
lastMessageDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *options) SeparatedMultimodalModel() bool {
|
func (o *options) SeparatedMultimodalModel() bool {
|
||||||
@@ -151,6 +153,17 @@ func EnableKnowledgeBaseWithResults(results int) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithLastMessageDuration(duration string) Option {
|
||||||
|
return func(o *options) error {
|
||||||
|
d, err := time.ParseDuration(duration)
|
||||||
|
if err != nil {
|
||||||
|
d = types.DefaultLastMessageDuration
|
||||||
|
}
|
||||||
|
o.lastMessageDuration = d
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithParallelJobs(jobs int) Option {
|
func WithParallelJobs(jobs int) Option {
|
||||||
return func(o *options) error {
|
return func(o *options) error {
|
||||||
o.parallelJobs = jobs
|
o.parallelJobs = jobs
|
||||||
|
|||||||
@@ -14,10 +14,10 @@ import (
|
|||||||
// all information that should be displayed to the LLM
|
// all information that should be displayed to the LLM
|
||||||
// in the prompts
|
// in the prompts
|
||||||
type PromptHUD struct {
|
type PromptHUD struct {
|
||||||
Character Character `json:"character"`
|
Character Character `json:"character"`
|
||||||
CurrentState types.AgentInternalState `json:"current_state"`
|
CurrentState types.AgentInternalState `json:"current_state"`
|
||||||
PermanentGoal string `json:"permanent_goal"`
|
PermanentGoal string `json:"permanent_goal"`
|
||||||
ShowCharacter bool `json:"show_character"`
|
ShowCharacter bool `json:"show_character"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Character struct {
|
type Character struct {
|
||||||
|
|||||||
13
core/conversations/conversations_suite_test.go
Normal file
13
core/conversations/conversations_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package conversations_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConversations(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Conversations test suite")
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package connectors
|
package conversations
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
package connectors_test
|
package conversations_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mudler/LocalAGI/services/connectors"
|
"github.com/mudler/LocalAGI/core/conversations"
|
||||||
. "github.com/onsi/ginkgo/v2"
|
. "github.com/onsi/ginkgo/v2"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
"github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
@@ -11,13 +11,13 @@ import (
|
|||||||
|
|
||||||
var _ = Describe("ConversationTracker", func() {
|
var _ = Describe("ConversationTracker", func() {
|
||||||
var (
|
var (
|
||||||
tracker *connectors.ConversationTracker[string]
|
tracker *conversations.ConversationTracker[string]
|
||||||
duration time.Duration
|
duration time.Duration
|
||||||
)
|
)
|
||||||
|
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
duration = 1 * time.Second
|
duration = 1 * time.Second
|
||||||
tracker = connectors.NewConversationTracker[string](duration)
|
tracker = conversations.NewConversationTracker[string](duration)
|
||||||
})
|
})
|
||||||
|
|
||||||
It("should initialize with empty conversations", func() {
|
It("should initialize with empty conversations", func() {
|
||||||
@@ -81,8 +81,8 @@ var _ = Describe("ConversationTracker", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("should handle different key types", func() {
|
It("should handle different key types", func() {
|
||||||
trackerInt := connectors.NewConversationTracker[int](duration)
|
trackerInt := conversations.NewConversationTracker[int](duration)
|
||||||
trackerInt64 := connectors.NewConversationTracker[int64](duration)
|
trackerInt64 := conversations.NewConversationTracker[int64](duration)
|
||||||
|
|
||||||
message := openai.ChatCompletionMessage{
|
message := openai.ChatCompletionMessage{
|
||||||
Role: openai.ChatMessageRoleUser,
|
Role: openai.ChatMessageRoleUser,
|
||||||
@@ -48,12 +48,13 @@ type AgentConfig struct {
|
|||||||
|
|
||||||
Description string `json:"description" form:"description"`
|
Description string `json:"description" form:"description"`
|
||||||
|
|
||||||
Model string `json:"model" form:"model"`
|
Model string `json:"model" form:"model"`
|
||||||
MultimodalModel string `json:"multimodal_model" form:"multimodal_model"`
|
MultimodalModel string `json:"multimodal_model" form:"multimodal_model"`
|
||||||
APIURL string `json:"api_url" form:"api_url"`
|
APIURL string `json:"api_url" form:"api_url"`
|
||||||
APIKey string `json:"api_key" form:"api_key"`
|
APIKey string `json:"api_key" form:"api_key"`
|
||||||
LocalRAGURL string `json:"local_rag_url" form:"local_rag_url"`
|
LocalRAGURL string `json:"local_rag_url" form:"local_rag_url"`
|
||||||
LocalRAGAPIKey string `json:"local_rag_api_key" form:"local_rag_api_key"`
|
LocalRAGAPIKey string `json:"local_rag_api_key" form:"local_rag_api_key"`
|
||||||
|
LastMessageDuration string `json:"last_message_duration" form:"last_message_duration"`
|
||||||
|
|
||||||
Name string `json:"name" form:"name"`
|
Name string `json:"name" form:"name"`
|
||||||
HUD bool `json:"hud" form:"hud"`
|
HUD bool `json:"hud" form:"hud"`
|
||||||
@@ -329,6 +330,14 @@ func NewAgentConfigMeta(
|
|||||||
HelpText: "Maximum number of evaluation loops to perform when addressing gaps in responses",
|
HelpText: "Maximum number of evaluation loops to perform when addressing gaps in responses",
|
||||||
Tags: config.Tags{Section: "AdvancedSettings"},
|
Tags: config.Tags{Section: "AdvancedSettings"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "last_message_duration",
|
||||||
|
Label: "Last Message Duration",
|
||||||
|
Type: "text",
|
||||||
|
DefaultValue: "5m",
|
||||||
|
HelpText: "Duration for the last message to be considered in the conversation",
|
||||||
|
Tags: config.Tags{Section: "AdvancedSettings"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
MCPServers: []config.Field{
|
MCPServers: []config.Field{
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -462,6 +462,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O
|
|||||||
}),
|
}),
|
||||||
WithSystemPrompt(config.SystemPrompt),
|
WithSystemPrompt(config.SystemPrompt),
|
||||||
WithMultimodalModel(multimodalModel),
|
WithMultimodalModel(multimodalModel),
|
||||||
|
WithLastMessageDuration(config.LastMessageDuration),
|
||||||
WithAgentResultCallback(func(state types.ActionState) {
|
WithAgentResultCallback(func(state types.ActionState) {
|
||||||
a.Lock()
|
a.Lock()
|
||||||
if _, ok := a.agentStatus[name]; !ok {
|
if _, ok := a.agentStatus[name]; !ok {
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ func (a ActionDefinition) ToFunctionDefinition() *openai.FunctionDefinition {
|
|||||||
|
|
||||||
// Actions is something the agent can do
|
// Actions is something the agent can do
|
||||||
type Action interface {
|
type Action interface {
|
||||||
Run(ctx context.Context, action ActionParams) (ActionResult, error)
|
Run(ctx context.Context, sharedState *AgentSharedState, action ActionParams) (ActionResult, error)
|
||||||
Definition() ActionDefinition
|
Definition() ActionDefinition
|
||||||
Plannable() bool
|
Plannable() bool
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAGI/core/conversations"
|
||||||
|
)
|
||||||
|
|
||||||
// State is the structure
|
// State is the structure
|
||||||
// that is used to keep track of the current state
|
// that is used to keep track of the current state
|
||||||
@@ -20,6 +25,23 @@ type AgentInternalState struct {
|
|||||||
Goal string `json:"goal"`
|
Goal string `json:"goal"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultLastMessageDuration = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
type AgentSharedState struct {
|
||||||
|
ConversationTracker *conversations.ConversationTracker[string] `json:"conversation_tracker"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAgentSharedState(lastMessageDuration time.Duration) *AgentSharedState {
|
||||||
|
if lastMessageDuration == 0 {
|
||||||
|
lastMessageDuration = DefaultLastMessageDuration
|
||||||
|
}
|
||||||
|
return &AgentSharedState{
|
||||||
|
ConversationTracker: conversations.NewConversationTracker[string](lastMessageDuration),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const fmtT = `=====================
|
const fmtT = `=====================
|
||||||
NowDoing: %s
|
NowDoing: %s
|
||||||
DoingNext: %s
|
DoingNext: %s
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func NewBrowse(config map[string]string) *BrowseAction {
|
|||||||
|
|
||||||
type BrowseAction struct{}
|
type BrowseAction struct{}
|
||||||
|
|
||||||
func (a *BrowseAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *BrowseAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
}{}
|
}{}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func NewBrowserAgentRunner(config map[string]string, defaultURL string) *Browser
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BrowserAgentRunner) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (b *BrowserAgentRunner) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := api.AgentRequest{}
|
result := api.AgentRequest{}
|
||||||
err := params.Unmarshal(&result)
|
err := params.Unmarshal(&result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ type CallAgentAction struct {
|
|||||||
blacklist []string
|
blacklist []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *CallAgentAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
AgentName string `json:"agent_name"`
|
AgentName string `json:"agent_name"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ func NewCounter(config map[string]string) *CounterAction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run executes the counter action
|
// Run executes the counter action
|
||||||
func (a *CounterAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *CounterAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
// Parse parameters
|
// Parse parameters
|
||||||
request := struct {
|
request := struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func NewDeepResearchRunner(config map[string]string, defaultURL string) *DeepRes
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DeepResearchRunner) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (d *DeepResearchRunner) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := api.DeepResearchRequest{}
|
result := api.DeepResearchRequest{}
|
||||||
err := params.Unmarshal(&result)
|
err := params.Unmarshal(&result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ type GenImageAction struct {
|
|||||||
imageModel string
|
imageModel string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *GenImageAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *GenImageAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Size string `json:"size"`
|
Size string `json:"size"`
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ var _ = Describe("GenImageAction", func() {
|
|||||||
"size": "256x256",
|
"size": "256x256",
|
||||||
}
|
}
|
||||||
|
|
||||||
url, err := action.Run(ctx, params)
|
url, err := action.Run(ctx, nil, params)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(url).ToNot(BeEmpty())
|
Expect(url).ToNot(BeEmpty())
|
||||||
})
|
})
|
||||||
@@ -52,7 +52,7 @@ var _ = Describe("GenImageAction", func() {
|
|||||||
"size": "256x256",
|
"size": "256x256",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := action.Run(ctx, params)
|
_, err := action.Run(ctx, nil, params)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func NewGithubIssueCloser(config map[string]string) *GithubIssuesCloser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubIssuesCloser) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubIssuesCloser) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func NewGithubIssueCommenter(config map[string]string) *GithubIssuesCommenter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubIssuesCommenter) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubIssuesCommenter) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func NewGithubIssueEditor(config map[string]string) *GithubIssueEditor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubIssueEditor) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubIssueEditor) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func NewGithubIssueLabeler(config map[string]string) *GithubIssuesLabeler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubIssuesLabeler) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubIssuesLabeler) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func NewGithubIssueOpener(config map[string]string) *GithubIssuesOpener {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubIssuesOpener) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubIssuesOpener) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
Body string `json:"text"`
|
Body string `json:"text"`
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func NewGithubIssueReader(config map[string]string) *GithubIssuesReader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubIssuesReader) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubIssuesReader) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func NewGithubIssueSearch(config map[string]string) *GithubIssueSearch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubIssueSearch) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubIssueSearch) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Query string `json:"query"`
|
Query string `json:"query"`
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package actions
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/google/go-github/v69/github"
|
"github.com/google/go-github/v69/github"
|
||||||
"github.com/mudler/LocalAGI/core/types"
|
"github.com/mudler/LocalAGI/core/types"
|
||||||
@@ -17,96 +15,6 @@ type GithubPRCommenter struct {
|
|||||||
client *github.Client
|
client *github.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
patchRegex = regexp.MustCompile(`^@@.*\d [\+\-](\d+),?(\d+)?.+?@@`)
|
|
||||||
)
|
|
||||||
|
|
||||||
type commitFileInfo struct {
|
|
||||||
FileName string
|
|
||||||
hunkInfos []*hunkInfo
|
|
||||||
sha string
|
|
||||||
}
|
|
||||||
|
|
||||||
type hunkInfo struct {
|
|
||||||
hunkStart int
|
|
||||||
hunkEnd int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hi hunkInfo) isLineInHunk(line int) bool {
|
|
||||||
return line >= hi.hunkStart && line <= hi.hunkEnd
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfi *commitFileInfo) getHunkInfo(line int) *hunkInfo {
|
|
||||||
for _, hunkInfo := range cfi.hunkInfos {
|
|
||||||
if hunkInfo.isLineInHunk(line) {
|
|
||||||
return hunkInfo
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfi *commitFileInfo) isLineInChange(line int) bool {
|
|
||||||
return cfi.getHunkInfo(line) != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cfi commitFileInfo) calculatePosition(line int) *int {
|
|
||||||
hi := cfi.getHunkInfo(line)
|
|
||||||
if hi == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
position := line - hi.hunkStart
|
|
||||||
return &position
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseHunkPositions(patch, filename string) ([]*hunkInfo, error) {
|
|
||||||
hunkInfos := make([]*hunkInfo, 0)
|
|
||||||
if patch != "" {
|
|
||||||
groups := patchRegex.FindAllStringSubmatch(patch, -1)
|
|
||||||
if len(groups) < 1 {
|
|
||||||
return hunkInfos, fmt.Errorf("the patch details for [%s] could not be resolved", filename)
|
|
||||||
}
|
|
||||||
for _, patchGroup := range groups {
|
|
||||||
endPos := 2
|
|
||||||
if len(patchGroup) > 2 && patchGroup[2] == "" {
|
|
||||||
endPos = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
hunkStart, err := strconv.Atoi(patchGroup[1])
|
|
||||||
if err != nil {
|
|
||||||
hunkStart = -1
|
|
||||||
}
|
|
||||||
hunkEnd, err := strconv.Atoi(patchGroup[endPos])
|
|
||||||
if err != nil {
|
|
||||||
hunkEnd = -1
|
|
||||||
}
|
|
||||||
hunkInfos = append(hunkInfos, &hunkInfo{
|
|
||||||
hunkStart: hunkStart,
|
|
||||||
hunkEnd: hunkEnd,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return hunkInfos, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCommitInfo(file *github.CommitFile) (*commitFileInfo, error) {
|
|
||||||
patch := file.GetPatch()
|
|
||||||
hunkInfos, err := parseHunkPositions(patch, *file.Filename)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sha := file.GetSHA()
|
|
||||||
if sha == "" {
|
|
||||||
return nil, fmt.Errorf("the sha details for [%s] could not be resolved", *file.Filename)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &commitFileInfo{
|
|
||||||
FileName: *file.Filename,
|
|
||||||
hunkInfos: hunkInfos,
|
|
||||||
sha: sha,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewGithubPRCommenter(config map[string]string) *GithubPRCommenter {
|
func NewGithubPRCommenter(config map[string]string) *GithubPRCommenter {
|
||||||
client := github.NewClient(nil).WithAuthToken(config["token"])
|
client := github.NewClient(nil).WithAuthToken(config["token"])
|
||||||
|
|
||||||
@@ -119,7 +27,7 @@ func NewGithubPRCommenter(config map[string]string) *GithubPRCommenter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubPRCommenter) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubPRCommenter) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -148,7 +148,7 @@ func (g *GithubPRCreator) createOrUpdateFile(ctx context.Context, branch string,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubPRCreator) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubPRCreator) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ var _ = Describe("GithubPRCreator", func() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := action.Run(ctx, params)
|
result, err := action.Run(ctx, nil, params)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(result.Result).To(ContainSubstring("pull request #"))
|
Expect(result.Result).To(ContainSubstring("pull request #"))
|
||||||
})
|
})
|
||||||
@@ -65,7 +65,7 @@ var _ = Describe("GithubPRCreator", func() {
|
|||||||
"body": "This is a test pull request",
|
"body": "This is a test pull request",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := action.Run(ctx, params)
|
_, err := action.Run(ctx, nil, params)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func NewGithubPRReader(config map[string]string) *GithubPRReader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubPRReader) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubPRReader) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func NewGithubPRReviewer(config map[string]string) *GithubPRReviewer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubPRReviewer) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubPRReviewer) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ var _ = Describe("GithubPRReviewer", func() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := reviewer.Run(ctx, params)
|
result, err := reviewer.Run(ctx, nil, params)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(result.Result).To(ContainSubstring("reviewed successfully"))
|
Expect(result.Result).To(ContainSubstring("reviewed successfully"))
|
||||||
})
|
})
|
||||||
@@ -70,7 +70,7 @@ var _ = Describe("GithubPRReviewer", func() {
|
|||||||
"review_action": "COMMENT",
|
"review_action": "COMMENT",
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := reviewer.Run(ctx, params)
|
result, err := reviewer.Run(ctx, nil, params)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
Expect(result.Result).To(ContainSubstring("not found"))
|
Expect(result.Result).To(ContainSubstring("not found"))
|
||||||
})
|
})
|
||||||
@@ -85,7 +85,7 @@ var _ = Describe("GithubPRReviewer", func() {
|
|||||||
"review_action": "INVALID_ACTION",
|
"review_action": "INVALID_ACTION",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := reviewer.Run(ctx, params)
|
_, err := reviewer.Run(ctx, nil, params)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func NewGithubRepositoryCreateOrUpdateContent(config map[string]string) *GithubR
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubRepositoryCreateOrUpdateContent) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubRepositoryCreateOrUpdateContent) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Path string `json:"path"`
|
Path string `json:"path"`
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ func (g *GithubRepositoryGetAllContent) getContentRecursively(ctx context.Contex
|
|||||||
return result.String(), nil
|
return result.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubRepositoryGetAllContent) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubRepositoryGetAllContent) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ var _ = Describe("GithubRepositoryGetAllContent", func() {
|
|||||||
"path": ".",
|
"path": ".",
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := action.Run(ctx, params)
|
result, err := action.Run(ctx, nil, params)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(result.Result).NotTo(BeEmpty())
|
Expect(result.Result).NotTo(BeEmpty())
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ var _ = Describe("GithubRepositoryGetAllContent", func() {
|
|||||||
"path": "non-existent-path",
|
"path": "non-existent-path",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := action.Run(ctx, params)
|
_, err := action.Run(ctx, nil, params)
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func NewGithubRepositoryGetContent(config map[string]string) *GithubRepositoryGe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubRepositoryGetContent) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubRepositoryGetContent) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Path string `json:"path"`
|
Path string `json:"path"`
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func (g *GithubRepositoryListFiles) listFilesRecursively(ctx context.Context, pa
|
|||||||
return files, nil
|
return files, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubRepositoryListFiles) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubRepositoryListFiles) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func NewGithubRepositoryREADME(config map[string]string) *GithubRepositoryREADME
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubRepositoryREADME) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubRepositoryREADME) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func (g *GithubRepositorySearchFiles) searchFilesRecursively(ctx context.Context
|
|||||||
return result.String(), nil
|
return result.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GithubRepositorySearchFiles) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (g *GithubRepositorySearchFiles) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Repository string `json:"repository"`
|
Repository string `json:"repository"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func NewScraper(config map[string]string) *ScraperAction {
|
|||||||
|
|
||||||
type ScraperAction struct{}
|
type ScraperAction struct{}
|
||||||
|
|
||||||
func (a *ScraperAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *ScraperAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
URL string `json:"url"`
|
URL string `json:"url"`
|
||||||
}{}
|
}{}
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ func NewSearch(config map[string]string) *SearchAction {
|
|||||||
|
|
||||||
type SearchAction struct{ results int }
|
type SearchAction struct{ results int }
|
||||||
|
|
||||||
func (a *SearchAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *SearchAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Query string `json:"query"`
|
Query string `json:"query"`
|
||||||
}{}
|
}{}
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ type SendMailAction struct {
|
|||||||
smtpPort string
|
smtpPort string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *SendMailAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *SendMailAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
To string `json:"to"`
|
To string `json:"to"`
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/mudler/LocalAGI/core/types"
|
"github.com/mudler/LocalAGI/core/types"
|
||||||
"github.com/mudler/LocalAGI/pkg/config"
|
"github.com/mudler/LocalAGI/pkg/config"
|
||||||
"github.com/mudler/LocalAGI/pkg/xstrings"
|
"github.com/mudler/LocalAGI/pkg/xstrings"
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,9 +20,11 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type SendTelegramMessageRunner struct {
|
type SendTelegramMessageRunner struct {
|
||||||
token string
|
token string
|
||||||
chatID int64
|
chatID int64
|
||||||
bot *bot.Bot
|
bot *bot.Bot
|
||||||
|
customName string
|
||||||
|
customDescription string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSendTelegramMessageRunner(config map[string]string) *SendTelegramMessageRunner {
|
func NewSendTelegramMessageRunner(config map[string]string) *SendTelegramMessageRunner {
|
||||||
@@ -46,9 +49,11 @@ func NewSendTelegramMessageRunner(config map[string]string) *SendTelegramMessage
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &SendTelegramMessageRunner{
|
return &SendTelegramMessageRunner{
|
||||||
token: token,
|
token: token,
|
||||||
chatID: chatID,
|
chatID: chatID,
|
||||||
bot: b,
|
bot: b,
|
||||||
|
customName: config["custom_name"],
|
||||||
|
customDescription: config["custom_description"],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,7 +62,7 @@ type TelegramMessageParams struct {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SendTelegramMessageRunner) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (s *SendTelegramMessageRunner) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
var messageParams TelegramMessageParams
|
var messageParams TelegramMessageParams
|
||||||
err := params.Unmarshal(&messageParams)
|
err := params.Unmarshal(&messageParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -95,6 +100,11 @@ func (s *SendTelegramMessageRunner) Run(ctx context.Context, params types.Action
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sharedState.ConversationTracker.AddMessage(fmt.Sprintf("telegram:%d", messageParams.ChatID), openai.ChatCompletionMessage{
|
||||||
|
Content: messageParams.Message,
|
||||||
|
Role: "assistant",
|
||||||
|
})
|
||||||
|
|
||||||
return types.ActionResult{
|
return types.ActionResult{
|
||||||
Result: fmt.Sprintf("Message sent successfully to chat ID %d in %d parts", messageParams.ChatID, len(messages)),
|
Result: fmt.Sprintf("Message sent successfully to chat ID %d in %d parts", messageParams.ChatID, len(messages)),
|
||||||
Metadata: map[string]interface{}{
|
Metadata: map[string]interface{}{
|
||||||
@@ -104,10 +114,21 @@ func (s *SendTelegramMessageRunner) Run(ctx context.Context, params types.Action
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SendTelegramMessageRunner) Definition() types.ActionDefinition {
|
func (s *SendTelegramMessageRunner) Definition() types.ActionDefinition {
|
||||||
|
|
||||||
|
customName := "send_telegram_message"
|
||||||
|
if s.customName != "" {
|
||||||
|
customName = s.customName
|
||||||
|
}
|
||||||
|
|
||||||
|
customDescription := "Send a message to a Telegram user or group"
|
||||||
|
if s.customDescription != "" {
|
||||||
|
customDescription = s.customDescription
|
||||||
|
}
|
||||||
|
|
||||||
if s.chatID != 0 {
|
if s.chatID != 0 {
|
||||||
return types.ActionDefinition{
|
return types.ActionDefinition{
|
||||||
Name: "send_telegram_message",
|
Name: types.ActionDefinitionName(customName),
|
||||||
Description: "Send a message to a Telegram user or group",
|
Description: customDescription,
|
||||||
Properties: map[string]jsonschema.Definition{
|
Properties: map[string]jsonschema.Definition{
|
||||||
"message": {
|
"message": {
|
||||||
Type: jsonschema.String,
|
Type: jsonschema.String,
|
||||||
@@ -119,8 +140,8 @@ func (s *SendTelegramMessageRunner) Definition() types.ActionDefinition {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return types.ActionDefinition{
|
return types.ActionDefinition{
|
||||||
Name: "send_telegram_message",
|
Name: types.ActionDefinitionName(customName),
|
||||||
Description: "Send a message to a Telegram user or group",
|
Description: customDescription,
|
||||||
Properties: map[string]jsonschema.Definition{
|
Properties: map[string]jsonschema.Definition{
|
||||||
"chat_id": {
|
"chat_id": {
|
||||||
Type: jsonschema.Number,
|
Type: jsonschema.Number,
|
||||||
@@ -156,5 +177,19 @@ func SendTelegramMessageConfigMeta() []config.Field {
|
|||||||
Required: false,
|
Required: false,
|
||||||
HelpText: "Default Telegram chat ID to send messages to (can be overridden in parameters)",
|
HelpText: "Default Telegram chat ID to send messages to (can be overridden in parameters)",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "custom_name",
|
||||||
|
Label: "Custom Name",
|
||||||
|
Type: config.FieldTypeText,
|
||||||
|
Required: false,
|
||||||
|
HelpText: "Custom name for the action (optional, defaults to 'send_telegram_message')",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "custom_description",
|
||||||
|
Label: "Custom Description",
|
||||||
|
Type: config.FieldTypeText,
|
||||||
|
Required: false,
|
||||||
|
HelpText: "Custom description for the action (optional, defaults to 'Send a message to a Telegram user or group')",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ type ShellAction struct {
|
|||||||
customDescription string
|
customDescription string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *ShellAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *ShellAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Command string `json:"command"`
|
Command string `json:"command"`
|
||||||
Host string `json:"host"`
|
Host string `json:"host"`
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ type PostTweetAction struct {
|
|||||||
noCharacterLimit bool
|
noCharacterLimit bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *PostTweetAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *PostTweetAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
}{}
|
}{}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ func NewWikipedia(config map[string]string) *WikipediaAction {
|
|||||||
|
|
||||||
type WikipediaAction struct{}
|
type WikipediaAction struct{}
|
||||||
|
|
||||||
func (a *WikipediaAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *WikipediaAction) Run(ctx context.Context, sharedState *types.AgentSharedState, params types.ActionParams) (types.ActionResult, error) {
|
||||||
result := struct {
|
result := struct {
|
||||||
Query string `json:"query"`
|
Query string `json:"query"`
|
||||||
}{}
|
}{}
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ package connectors
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/bwmarrin/discordgo"
|
"github.com/bwmarrin/discordgo"
|
||||||
"github.com/mudler/LocalAGI/core/agent"
|
"github.com/mudler/LocalAGI/core/agent"
|
||||||
@@ -14,9 +14,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Discord struct {
|
type Discord struct {
|
||||||
token string
|
token string
|
||||||
defaultChannel string
|
defaultChannel string
|
||||||
conversationTracker *ConversationTracker[string]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDiscord creates a new Discord connector
|
// NewDiscord creates a new Discord connector
|
||||||
@@ -25,11 +24,6 @@ type Discord struct {
|
|||||||
// - defaultChannel: Discord channel to always answer even if not mentioned
|
// - defaultChannel: Discord channel to always answer even if not mentioned
|
||||||
func NewDiscord(config map[string]string) *Discord {
|
func NewDiscord(config map[string]string) *Discord {
|
||||||
|
|
||||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
|
||||||
if err != nil {
|
|
||||||
duration = 5 * time.Minute
|
|
||||||
}
|
|
||||||
|
|
||||||
token := config["token"]
|
token := config["token"]
|
||||||
|
|
||||||
if !strings.HasPrefix(token, "Bot ") {
|
if !strings.HasPrefix(token, "Bot ") {
|
||||||
@@ -37,9 +31,8 @@ func NewDiscord(config map[string]string) *Discord {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Discord{
|
return &Discord{
|
||||||
conversationTracker: NewConversationTracker[string](duration),
|
token: token,
|
||||||
token: token,
|
defaultChannel: config["defaultChannel"],
|
||||||
defaultChannel: config["defaultChannel"],
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,12 +150,12 @@ func (d *Discord) handleThreadMessage(a *agent.Agent, s *discordgo.Session, m *d
|
|||||||
|
|
||||||
func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
|
func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||||
|
|
||||||
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{
|
a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("discord:%s", m.ChannelID), openai.ChatCompletionMessage{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: m.Content,
|
Content: m.Content,
|
||||||
})
|
})
|
||||||
|
|
||||||
conv := d.conversationTracker.GetConversation(m.ChannelID)
|
conv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("discord:%s", m.ChannelID))
|
||||||
|
|
||||||
jobResult := a.Ask(
|
jobResult := a.Ask(
|
||||||
types.WithConversationHistory(conv),
|
types.WithConversationHistory(conv),
|
||||||
@@ -173,7 +166,7 @@ func (d *Discord) handleChannelMessage(a *agent.Agent, s *discordgo.Session, m *
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
d.conversationTracker.AddMessage(m.ChannelID, openai.ChatCompletionMessage{
|
a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("discord:%s", m.ChannelID), openai.ChatCompletionMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: jobResult.Response,
|
Content: jobResult.Response,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -15,28 +15,22 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type IRC struct {
|
type IRC struct {
|
||||||
server string
|
server string
|
||||||
port string
|
port string
|
||||||
nickname string
|
nickname string
|
||||||
channel string
|
channel string
|
||||||
conn *irc.Connection
|
conn *irc.Connection
|
||||||
alwaysReply bool
|
alwaysReply bool
|
||||||
conversationTracker *ConversationTracker[string]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIRC(config map[string]string) *IRC {
|
func NewIRC(config map[string]string) *IRC {
|
||||||
|
|
||||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
|
||||||
if err != nil {
|
|
||||||
duration = 5 * time.Minute
|
|
||||||
}
|
|
||||||
return &IRC{
|
return &IRC{
|
||||||
server: config["server"],
|
server: config["server"],
|
||||||
port: config["port"],
|
port: config["port"],
|
||||||
nickname: config["nickname"],
|
nickname: config["nickname"],
|
||||||
channel: config["channel"],
|
channel: config["channel"],
|
||||||
alwaysReply: config["alwaysReply"] == "true",
|
alwaysReply: config["alwaysReply"] == "true",
|
||||||
conversationTracker: NewConversationTracker[string](duration),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,7 +109,7 @@ func (i *IRC) Start(a *agent.Agent) {
|
|||||||
cleanedMessage := cleanUpMessage(message, i.nickname)
|
cleanedMessage := cleanUpMessage(message, i.nickname)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conv := i.conversationTracker.GetConversation(channel)
|
conv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("irc:%s", channel))
|
||||||
|
|
||||||
conv = append(conv,
|
conv = append(conv,
|
||||||
openai.ChatCompletionMessage{
|
openai.ChatCompletionMessage{
|
||||||
@@ -125,7 +119,7 @@ func (i *IRC) Start(a *agent.Agent) {
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Update the conversation history
|
// Update the conversation history
|
||||||
i.conversationTracker.AddMessage(channel, openai.ChatCompletionMessage{
|
a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("irc:%s", channel), openai.ChatCompletionMessage{
|
||||||
Content: cleanedMessage,
|
Content: cleanedMessage,
|
||||||
Role: "user",
|
Role: "user",
|
||||||
})
|
})
|
||||||
@@ -140,7 +134,7 @@ func (i *IRC) Start(a *agent.Agent) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update the conversation history
|
// Update the conversation history
|
||||||
i.conversationTracker.AddMessage(channel, openai.ChatCompletionMessage{
|
a.SharedState().ConversationTracker.AddMessage(fmt.Sprintf("irc:%s", channel), openai.ChatCompletionMessage{
|
||||||
Content: res.Response,
|
Content: res.Response,
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
})
|
})
|
||||||
@@ -209,7 +203,7 @@ func (i *IRC) Start(a *agent.Agent) {
|
|||||||
// Start the IRC client in a goroutine
|
// Start the IRC client in a goroutine
|
||||||
go i.conn.Loop()
|
go i.conn.Loop()
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-a.Context().Done():
|
case <-a.Context().Done():
|
||||||
i.conn.Quit()
|
i.conn.Quit()
|
||||||
return
|
return
|
||||||
@@ -249,11 +243,5 @@ func IRCConfigMeta() []config.Field {
|
|||||||
Label: "Always Reply",
|
Label: "Always Reply",
|
||||||
Type: config.FieldTypeCheckbox,
|
Type: config.FieldTypeCheckbox,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Name: "lastMessageDuration",
|
|
||||||
Label: "Last Message Duration",
|
|
||||||
Type: config.FieldTypeText,
|
|
||||||
DefaultValue: "5m",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,27 +31,20 @@ type Matrix struct {
|
|||||||
// Track active jobs for cancellation
|
// Track active jobs for cancellation
|
||||||
activeJobs map[string][]*types.Job // map[roomID]bool to track if a room has active processing
|
activeJobs map[string][]*types.Job // map[roomID]bool to track if a room has active processing
|
||||||
activeJobsMutex sync.RWMutex
|
activeJobsMutex sync.RWMutex
|
||||||
|
|
||||||
conversationTracker *ConversationTracker[string]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const matrixThinkingMessage = "🤔 thinking..."
|
const matrixThinkingMessage = "🤔 thinking..."
|
||||||
|
|
||||||
func NewMatrix(config map[string]string) *Matrix {
|
func NewMatrix(config map[string]string) *Matrix {
|
||||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
|
||||||
if err != nil {
|
|
||||||
duration = 5 * time.Minute
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Matrix{
|
return &Matrix{
|
||||||
homeserverURL: config["homeserverURL"],
|
homeserverURL: config["homeserverURL"],
|
||||||
userID: config["userID"],
|
userID: config["userID"],
|
||||||
accessToken: config["accessToken"],
|
accessToken: config["accessToken"],
|
||||||
roomID: config["roomID"],
|
roomID: config["roomID"],
|
||||||
roomMode: config["roomMode"] == "true",
|
roomMode: config["roomMode"] == "true",
|
||||||
conversationTracker: NewConversationTracker[string](duration),
|
placeholders: make(map[string]string),
|
||||||
placeholders: make(map[string]string),
|
activeJobs: make(map[string][]*types.Job),
|
||||||
activeJobs: make(map[string][]*types.Job),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,7 +142,7 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) {
|
|||||||
// Cancel any active job for this room before starting a new one
|
// Cancel any active job for this room before starting a new one
|
||||||
m.cancelActiveJobForRoom(evt.RoomID.String())
|
m.cancelActiveJobForRoom(evt.RoomID.String())
|
||||||
|
|
||||||
currentConv := m.conversationTracker.GetConversation(evt.RoomID.String())
|
currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("matrix:%s", evt.RoomID.String()))
|
||||||
|
|
||||||
message := evt.Content.AsMessage().Body
|
message := evt.Content.AsMessage().Body
|
||||||
|
|
||||||
@@ -163,8 +156,8 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) {
|
|||||||
Content: message,
|
Content: message,
|
||||||
})
|
})
|
||||||
|
|
||||||
m.conversationTracker.AddMessage(
|
a.SharedState().ConversationTracker.AddMessage(
|
||||||
evt.RoomID.String(), currentConv[len(currentConv)-1],
|
fmt.Sprintf("matrix:%s", evt.RoomID.String()), currentConv[len(currentConv)-1],
|
||||||
)
|
)
|
||||||
|
|
||||||
agentOptions = append(agentOptions, types.WithConversationHistory(currentConv))
|
agentOptions = append(agentOptions, types.WithConversationHistory(currentConv))
|
||||||
@@ -209,8 +202,8 @@ func (m *Matrix) handleRoomMessage(a *agent.Agent, evt *event.Event) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.conversationTracker.AddMessage(
|
a.SharedState().ConversationTracker.AddMessage(
|
||||||
evt.RoomID.String(), openai.ChatCompletionMessage{
|
fmt.Sprintf("matrix:%s", evt.RoomID.String()), openai.ChatCompletionMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: res.Response,
|
Content: res.Response,
|
||||||
},
|
},
|
||||||
@@ -307,11 +300,5 @@ func MatrixConfigMeta() []config.Field {
|
|||||||
Label: "Room Mode",
|
Label: "Room Mode",
|
||||||
Type: config.FieldTypeCheckbox,
|
Type: config.FieldTypeCheckbox,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Name: "lastMessageDuration",
|
|
||||||
Label: "Last Message Duration",
|
|
||||||
Type: config.FieldTypeText,
|
|
||||||
DefaultValue: "5m",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAGI/pkg/config"
|
"github.com/mudler/LocalAGI/pkg/config"
|
||||||
"github.com/mudler/LocalAGI/pkg/localoperator"
|
"github.com/mudler/LocalAGI/pkg/localoperator"
|
||||||
@@ -42,27 +41,19 @@ type Slack struct {
|
|||||||
// Track active jobs for cancellation
|
// Track active jobs for cancellation
|
||||||
activeJobs map[string][]*types.Job // map[channelID]bool to track if a channel has active processing
|
activeJobs map[string][]*types.Job // map[channelID]bool to track if a channel has active processing
|
||||||
activeJobsMutex sync.RWMutex
|
activeJobsMutex sync.RWMutex
|
||||||
|
|
||||||
conversationTracker *ConversationTracker[string]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const thinkingMessage = ":hourglass: thinking..."
|
const thinkingMessage = ":hourglass: thinking..."
|
||||||
|
|
||||||
func NewSlack(config map[string]string) *Slack {
|
func NewSlack(config map[string]string) *Slack {
|
||||||
|
|
||||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
|
||||||
if err != nil {
|
|
||||||
duration = 5 * time.Minute
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Slack{
|
return &Slack{
|
||||||
appToken: config["appToken"],
|
appToken: config["appToken"],
|
||||||
botToken: config["botToken"],
|
botToken: config["botToken"],
|
||||||
channelID: config["channelID"],
|
channelID: config["channelID"],
|
||||||
channelMode: config["channelMode"] == "true",
|
channelMode: config["channelMode"] == "true",
|
||||||
conversationTracker: NewConversationTracker[string](duration),
|
placeholders: make(map[string]string),
|
||||||
placeholders: make(map[string]string),
|
activeJobs: make(map[string][]*types.Job),
|
||||||
activeJobs: make(map[string][]*types.Job),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,16 +131,6 @@ func cleanUpUsernameFromMessage(message string, b *slack.AuthTestResponse) strin
|
|||||||
return cleaned
|
return cleaned
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractUserIDsFromMessage(message string) []string {
|
|
||||||
var userIDs []string
|
|
||||||
for _, part := range strings.Split(message, " ") {
|
|
||||||
if strings.HasPrefix(part, "<@") && strings.HasSuffix(part, ">") {
|
|
||||||
userIDs = append(userIDs, strings.TrimPrefix(strings.TrimSuffix(part, ">"), "<@"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return userIDs
|
|
||||||
}
|
|
||||||
|
|
||||||
func replaceUserIDsWithNamesInMessage(api *slack.Client, message string) string {
|
func replaceUserIDsWithNamesInMessage(api *slack.Client, message string) string {
|
||||||
for _, part := range strings.Split(message, " ") {
|
for _, part := range strings.Split(message, " ") {
|
||||||
if strings.HasPrefix(part, "<@") && strings.HasSuffix(part, ">") {
|
if strings.HasPrefix(part, "<@") && strings.HasSuffix(part, ">") {
|
||||||
@@ -279,7 +260,7 @@ func (t *Slack) handleChannelMessage(
|
|||||||
// Cancel any active job for this channel before starting a new one
|
// Cancel any active job for this channel before starting a new one
|
||||||
t.cancelActiveJobForChannel(ev.Channel)
|
t.cancelActiveJobForChannel(ev.Channel)
|
||||||
|
|
||||||
currentConv := t.conversationTracker.GetConversation(t.channelID)
|
currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("slack:%s", t.channelID))
|
||||||
|
|
||||||
message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b))
|
message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b))
|
||||||
|
|
||||||
@@ -323,8 +304,8 @@ func (t *Slack) handleChannelMessage(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
t.conversationTracker.AddMessage(
|
a.SharedState().ConversationTracker.AddMessage(
|
||||||
t.channelID, currentConv[len(currentConv)-1],
|
fmt.Sprintf("slack:%s", t.channelID), currentConv[len(currentConv)-1],
|
||||||
)
|
)
|
||||||
|
|
||||||
agentOptions = append(agentOptions, types.WithConversationHistory(currentConv))
|
agentOptions = append(agentOptions, types.WithConversationHistory(currentConv))
|
||||||
@@ -370,14 +351,14 @@ func (t *Slack) handleChannelMessage(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
t.conversationTracker.AddMessage(
|
a.SharedState().ConversationTracker.AddMessage(
|
||||||
t.channelID, openai.ChatCompletionMessage{
|
fmt.Sprintf("slack:%s", t.channelID), openai.ChatCompletionMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: res.Response,
|
Content: res.Response,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
xlog.Debug("After adding message to conversation tracker", "conversation", t.conversationTracker.GetConversation(t.channelID))
|
xlog.Debug("After adding message to conversation tracker", "conversation", a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("slack:%s", t.channelID)))
|
||||||
|
|
||||||
//res.Response = githubmarkdownconvertergo.Slack(res.Response)
|
//res.Response = githubmarkdownconvertergo.Slack(res.Response)
|
||||||
|
|
||||||
@@ -752,6 +733,13 @@ func (t *Slack) Start(a *agent.Agent) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error(fmt.Sprintf("Error posting message: %v", err))
|
xlog.Error(fmt.Sprintf("Error posting message: %v", err))
|
||||||
}
|
}
|
||||||
|
a.SharedState().ConversationTracker.AddMessage(
|
||||||
|
fmt.Sprintf("slack:%s", t.channelID),
|
||||||
|
openai.ChatCompletionMessage{
|
||||||
|
Content: ccm.Content,
|
||||||
|
Role: "assistant",
|
||||||
|
},
|
||||||
|
)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -835,11 +823,5 @@ func SlackConfigMeta() []config.Field {
|
|||||||
Label: "Always Reply",
|
Label: "Always Reply",
|
||||||
Type: config.FieldTypeCheckbox,
|
Type: config.FieldTypeCheckbox,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Name: "lastMessageDuration",
|
|
||||||
Label: "Last Message Duration",
|
|
||||||
Type: config.FieldTypeText,
|
|
||||||
DefaultValue: "5m",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-telegram/bot"
|
"github.com/go-telegram/bot"
|
||||||
"github.com/go-telegram/bot/models"
|
"github.com/go-telegram/bot/models"
|
||||||
@@ -35,14 +34,8 @@ type Telegram struct {
|
|||||||
bot *bot.Bot
|
bot *bot.Bot
|
||||||
agent *agent.Agent
|
agent *agent.Agent
|
||||||
|
|
||||||
currentconversation map[int64][]openai.ChatCompletionMessage
|
|
||||||
lastMessageTime map[int64]time.Time
|
|
||||||
lastMessageDuration time.Duration
|
|
||||||
|
|
||||||
admins []string
|
admins []string
|
||||||
|
|
||||||
conversationTracker *ConversationTracker[int64]
|
|
||||||
|
|
||||||
// To track placeholder messages
|
// To track placeholder messages
|
||||||
placeholders map[string]int // map[jobUUID]messageID
|
placeholders map[string]int // map[jobUUID]messageID
|
||||||
placeholderMutex sync.RWMutex
|
placeholderMutex sync.RWMutex
|
||||||
@@ -50,6 +43,8 @@ type Telegram struct {
|
|||||||
// Track active jobs for cancellation
|
// Track active jobs for cancellation
|
||||||
activeJobs map[int64][]*types.Job // map[chatID]bool to track if a chat has active processing
|
activeJobs map[int64][]*types.Job // map[chatID]bool to track if a chat has active processing
|
||||||
activeJobsMutex sync.RWMutex
|
activeJobsMutex sync.RWMutex
|
||||||
|
|
||||||
|
channelID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send any text message to the bot after the bot has been started
|
// Send any text message to the bot after the bot has been started
|
||||||
@@ -219,6 +214,8 @@ func formatResponseWithURLs(response string, urls []string) string {
|
|||||||
func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent, update *models.Update) {
|
func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent, update *models.Update) {
|
||||||
username := update.Message.From.Username
|
username := update.Message.From.Username
|
||||||
|
|
||||||
|
xlog.Debug("Received message from user", "username", username, "chatID", update.Message.Chat.ID, "message", update.Message.Text)
|
||||||
|
|
||||||
internalError := func(err error, msg *models.Message) {
|
internalError := func(err error, msg *models.Message) {
|
||||||
xlog.Error("Error updating final message", "error", err)
|
xlog.Error("Error updating final message", "error", err)
|
||||||
b.EditMessageText(ctx, &bot.EditMessageTextParams{
|
b.EditMessageText(ctx, &bot.EditMessageTextParams{
|
||||||
@@ -242,14 +239,14 @@ func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent,
|
|||||||
// Cancel any active job for this chat before starting a new one
|
// Cancel any active job for this chat before starting a new one
|
||||||
t.cancelActiveJobForChat(update.Message.Chat.ID)
|
t.cancelActiveJobForChat(update.Message.Chat.ID)
|
||||||
|
|
||||||
currentConv := t.conversationTracker.GetConversation(update.Message.From.ID)
|
currentConv := a.SharedState().ConversationTracker.GetConversation(fmt.Sprintf("telegram:%d", update.Message.From.ID))
|
||||||
currentConv = append(currentConv, openai.ChatCompletionMessage{
|
currentConv = append(currentConv, openai.ChatCompletionMessage{
|
||||||
Content: update.Message.Text,
|
Content: update.Message.Text,
|
||||||
Role: "user",
|
Role: "user",
|
||||||
})
|
})
|
||||||
|
|
||||||
t.conversationTracker.AddMessage(
|
a.SharedState().ConversationTracker.AddMessage(
|
||||||
update.Message.From.ID,
|
fmt.Sprintf("telegram:%d", update.Message.From.ID),
|
||||||
openai.ChatCompletionMessage{
|
openai.ChatCompletionMessage{
|
||||||
Content: update.Message.Text,
|
Content: update.Message.Text,
|
||||||
Role: "user",
|
Role: "user",
|
||||||
@@ -328,8 +325,8 @@ func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
t.conversationTracker.AddMessage(
|
a.SharedState().ConversationTracker.AddMessage(
|
||||||
update.Message.From.ID,
|
fmt.Sprintf("telegram:%d", update.Message.From.ID),
|
||||||
openai.ChatCompletionMessage{
|
openai.ChatCompletionMessage{
|
||||||
Content: res.Response,
|
Content: res.Response,
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
@@ -408,11 +405,34 @@ func (t *Telegram) Start(a *agent.Agent) {
|
|||||||
t.agent = a
|
t.agent = a
|
||||||
|
|
||||||
// go func() {
|
// go func() {
|
||||||
// for m := range a.ConversationChannel() {
|
// forc m := range a.ConversationChannel() {
|
||||||
// t.handleNewMessage(ctx, b, m)
|
// t.handleNewMessage(ctx, b, m)
|
||||||
// }
|
// }
|
||||||
// }()
|
// }()
|
||||||
|
|
||||||
|
if t.channelID != "" {
|
||||||
|
// handle new conversations
|
||||||
|
a.AddSubscriber(func(ccm openai.ChatCompletionMessage) {
|
||||||
|
xlog.Debug("Subscriber(telegram)", "message", ccm.Content)
|
||||||
|
_, err := b.SendMessage(ctx, &bot.SendMessageParams{
|
||||||
|
ChatID: t.channelID,
|
||||||
|
Text: ccm.Content,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
xlog.Error("Error sending message", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t.agent.SharedState().ConversationTracker.AddMessage(
|
||||||
|
fmt.Sprintf("telegram:%s", t.channelID),
|
||||||
|
openai.ChatCompletionMessage{
|
||||||
|
Content: ccm.Content,
|
||||||
|
Role: "assistant",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
b.Start(ctx)
|
b.Start(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -422,11 +442,6 @@ func NewTelegramConnector(config map[string]string) (*Telegram, error) {
|
|||||||
return nil, errors.New("token is required")
|
return nil, errors.New("token is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
duration, err := time.ParseDuration(config["lastMessageDuration"])
|
|
||||||
if err != nil {
|
|
||||||
duration = 5 * time.Minute
|
|
||||||
}
|
|
||||||
|
|
||||||
admins := []string{}
|
admins := []string{}
|
||||||
|
|
||||||
if _, ok := config["admins"]; ok {
|
if _, ok := config["admins"]; ok {
|
||||||
@@ -434,14 +449,11 @@ func NewTelegramConnector(config map[string]string) (*Telegram, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Telegram{
|
return &Telegram{
|
||||||
Token: token,
|
Token: token,
|
||||||
lastMessageDuration: duration,
|
admins: admins,
|
||||||
admins: admins,
|
placeholders: make(map[string]int),
|
||||||
currentconversation: map[int64][]openai.ChatCompletionMessage{},
|
activeJobs: make(map[int64][]*types.Job),
|
||||||
lastMessageTime: map[int64]time.Time{},
|
channelID: config["channel_id"],
|
||||||
conversationTracker: NewConversationTracker[int64](duration),
|
|
||||||
placeholders: make(map[string]int),
|
|
||||||
activeJobs: make(map[int64][]*types.Job),
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -461,10 +473,10 @@ func TelegramConfigMeta() []config.Field {
|
|||||||
HelpText: "Comma-separated list of Telegram usernames that are allowed to interact with the bot",
|
HelpText: "Comma-separated list of Telegram usernames that are allowed to interact with the bot",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "lastMessageDuration",
|
Name: "channel_id",
|
||||||
Label: "Last Message Duration",
|
Label: "Channel ID",
|
||||||
Type: config.FieldTypeText,
|
Type: config.FieldTypeText,
|
||||||
DefaultValue: "5m",
|
HelpText: "Telegram channel ID to send messages to if the agent needs to initiate a conversation",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
18
webui/app.go
18
webui/app.go
@@ -11,12 +11,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/mudler/LocalAGI/core/conversations"
|
||||||
coreTypes "github.com/mudler/LocalAGI/core/types"
|
coreTypes "github.com/mudler/LocalAGI/core/types"
|
||||||
|
internalTypes "github.com/mudler/LocalAGI/core/types"
|
||||||
"github.com/mudler/LocalAGI/pkg/llm"
|
"github.com/mudler/LocalAGI/pkg/llm"
|
||||||
"github.com/mudler/LocalAGI/pkg/xlog"
|
"github.com/mudler/LocalAGI/pkg/xlog"
|
||||||
"github.com/mudler/LocalAGI/services"
|
"github.com/mudler/LocalAGI/services"
|
||||||
"github.com/mudler/LocalAGI/services/connectors"
|
|
||||||
"github.com/mudler/LocalAGI/webui/types"
|
"github.com/mudler/LocalAGI/webui/types"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
|
|
||||||
@@ -33,6 +35,7 @@ type (
|
|||||||
htmx *htmx.HTMX
|
htmx *htmx.HTMX
|
||||||
config *Config
|
config *Config
|
||||||
*fiber.App
|
*fiber.App
|
||||||
|
sharedState *internalTypes.AgentSharedState
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,9 +50,10 @@ func NewApp(opts ...Option) *App {
|
|||||||
})
|
})
|
||||||
|
|
||||||
a := &App{
|
a := &App{
|
||||||
htmx: htmx.New(),
|
htmx: htmx.New(),
|
||||||
config: config,
|
config: config,
|
||||||
App: webapp,
|
App: webapp,
|
||||||
|
sharedState: internalTypes.NewAgentSharedState(5 * time.Minute),
|
||||||
}
|
}
|
||||||
|
|
||||||
a.registerRoutes(config.Pool, webapp)
|
a.registerRoutes(config.Pool, webapp)
|
||||||
@@ -443,7 +447,7 @@ func (a *App) GetActionDefinition(pool *state.AgentPool) func(c *fiber.Ctx) erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error {
|
func (app *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
payload := struct {
|
payload := struct {
|
||||||
Config map[string]string `json:"config"`
|
Config map[string]string `json:"config"`
|
||||||
@@ -467,7 +471,7 @@ func (a *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error {
|
|||||||
ctx, cancel := context.WithTimeout(c.Context(), 200*time.Second)
|
ctx, cancel := context.WithTimeout(c.Context(), 200*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
res, err := a.Run(ctx, payload.Params)
|
res, err := a.Run(ctx, app.sharedState, payload.Params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("Error running action", "error", err)
|
xlog.Error("Error running action", "error", err)
|
||||||
return errorJSONMessage(c, err.Error())
|
return errorJSONMessage(c, err.Error())
|
||||||
@@ -484,7 +488,7 @@ func (a *App) ListActions() func(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) Responses(pool *state.AgentPool, tracker *connectors.ConversationTracker[string]) func(c *fiber.Ctx) error {
|
func (a *App) Responses(pool *state.AgentPool, tracker *conversations.ConversationTracker[string]) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
var request types.RequestBody
|
var request types.RequestBody
|
||||||
if err := c.BodyParser(&request); err != nil {
|
if err := c.BodyParser(&request); err != nil {
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ import (
|
|||||||
fiber "github.com/gofiber/fiber/v2"
|
fiber "github.com/gofiber/fiber/v2"
|
||||||
"github.com/gofiber/fiber/v2/middleware/filesystem"
|
"github.com/gofiber/fiber/v2/middleware/filesystem"
|
||||||
"github.com/gofiber/fiber/v2/middleware/keyauth"
|
"github.com/gofiber/fiber/v2/middleware/keyauth"
|
||||||
|
"github.com/mudler/LocalAGI/core/conversations"
|
||||||
"github.com/mudler/LocalAGI/core/sse"
|
"github.com/mudler/LocalAGI/core/sse"
|
||||||
"github.com/mudler/LocalAGI/services/connectors"
|
|
||||||
|
|
||||||
"github.com/mudler/LocalAGI/core/state"
|
"github.com/mudler/LocalAGI/core/state"
|
||||||
"github.com/mudler/LocalAGI/core/types"
|
"github.com/mudler/LocalAGI/core/types"
|
||||||
@@ -138,7 +138,7 @@ func (app *App) registerRoutes(pool *state.AgentPool, webapp *fiber.App) {
|
|||||||
|
|
||||||
webapp.Post("/api/chat/:name", app.Chat(pool))
|
webapp.Post("/api/chat/:name", app.Chat(pool))
|
||||||
|
|
||||||
conversationTracker := connectors.NewConversationTracker[string](app.config.ConversationStoreDuration)
|
conversationTracker := conversations.NewConversationTracker[string](app.config.ConversationStoreDuration)
|
||||||
|
|
||||||
webapp.Post("/v1/responses", app.Responses(pool, conversationTracker))
|
webapp.Post("/v1/responses", app.Responses(pool, conversationTracker))
|
||||||
|
|
||||||
@@ -268,7 +268,7 @@ func (app *App) registerRoutes(pool *state.AgentPool, webapp *fiber.App) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return c.JSON(fiber.Map{
|
return c.JSON(fiber.Map{
|
||||||
"Name": name,
|
"Name": name,
|
||||||
"History": agent.Observer().History(),
|
"History": agent.Observer().History(),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user