From 08785e2908315cfff24c9ea3027b1ff4a2f1c8a9 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 19 Mar 2025 22:58:35 +0100 Subject: [PATCH] feat: add action to call other agents (#60) Signed-off-by: Ettore Di Giacinto --- core/state/pool.go | 18 +++++-- services/actions.go | 8 +++- services/actions/callagents.go | 87 ++++++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 5 deletions(-) create mode 100644 services/actions/callagents.go diff --git a/core/state/pool.go b/core/state/pool.go index 790bc1d..b79d0b3 100644 --- a/core/state/pool.go +++ b/core/state/pool.go @@ -28,7 +28,7 @@ type AgentPool struct { managers map[string]sse.Manager agentStatus map[string]*Status apiURL, defaultModel, defaultMultimodalModel, localRAGAPI, apiKey string - availableActions func(*AgentConfig) func(ctx context.Context) []Action + availableActions func(*AgentConfig) func(ctx context.Context, pool *AgentPool) []Action connectors func(*AgentConfig) []Connector promptBlocks func(*AgentConfig) []PromptBlock timeout string @@ -68,7 +68,7 @@ func loadPoolFromFile(path string) (*AgentPoolData, error) { func NewAgentPool( defaultModel, defaultMultimodalModel, apiURL, apiKey, directory string, LocalRAGAPI string, - availableActions func(*AgentConfig) func(ctx context.Context) []agent.Action, + availableActions func(*AgentConfig) func(ctx context.Context, pool *AgentPool) []agent.Action, connectors func(*AgentConfig) []Connector, promptBlocks func(*AgentConfig) []PromptBlock, timeout string, @@ -185,7 +185,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error connectors := a.connectors(config) promptBlocks := a.promptBlocks(config) - actions := a.availableActions(config)(ctx) + actions := a.availableActions(config)(ctx, a) stateFile, characterFile := a.stateFiles(name) @@ -458,9 +458,21 @@ func (a *AgentPool) save() error { return os.WriteFile(a.file, data, 0644) } func (a *AgentPool) GetAgent(name string) *Agent { + a.Lock() + defer a.Unlock() return a.agents[name] } +func (a *AgentPool) AllAgents() []string { + a.Lock() + defer a.Unlock() + var agents []string + for agent := range a.agents { + agents = append(agents, agent) + } + return agents +} + func (a *AgentPool) GetConfig(name string) *AgentConfig { a.Lock() defer a.Unlock() diff --git a/services/actions.go b/services/actions.go index f7bc8e8..85c8252 100644 --- a/services/actions.go +++ b/services/actions.go @@ -31,6 +31,7 @@ const ( ActionSendMail = "send_mail" ActionGenerateImage = "generate_image" ActionCounter = "counter" + ActionCallAgents = "call_agents" ) var AvailableActions = []string{ @@ -51,10 +52,11 @@ var AvailableActions = []string{ ActionGenerateImage, ActionTwitterPost, ActionCounter, + ActionCallAgents, } -func Actions(a *state.AgentConfig) func(ctx context.Context) []agent.Action { - return func(ctx context.Context) []agent.Action { +func Actions(a *state.AgentConfig) func(ctx context.Context, pool *state.AgentPool) []agent.Action { + return func(ctx context.Context, pool *state.AgentPool) []agent.Action { allActions := []agent.Action{} for _, a := range a.Actions { @@ -104,6 +106,8 @@ func Actions(a *state.AgentConfig) func(ctx context.Context) []agent.Action { allActions = append(allActions, actions.NewPostTweet(config)) case ActionCounter: allActions = append(allActions, actions.NewCounter(config)) + case ActionCallAgents: + allActions = append(allActions, actions.NewCallAgent(config, pool)) } } diff --git a/services/actions/callagents.go b/services/actions/callagents.go new file mode 100644 index 0000000..e09d704 --- /dev/null +++ b/services/actions/callagents.go @@ -0,0 +1,87 @@ +package actions + +import ( + "context" + "fmt" + + "github.com/mudler/LocalAgent/core/action" + "github.com/mudler/LocalAgent/core/agent" + "github.com/mudler/LocalAgent/core/state" + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func NewCallAgent(config map[string]string, pool *state.AgentPool) *CallAgentAction { + return &CallAgentAction{ + pool: pool, + } +} + +type CallAgentAction struct { + pool *state.AgentPool +} + +func (a *CallAgentAction) Run(ctx context.Context, params action.ActionParams) (action.ActionResult, error) { + result := struct { + AgentName string `json:"agent_name"` + Message string `json:"message"` + }{} + err := params.Unmarshal(&result) + if err != nil { + fmt.Printf("error: %v", err) + + return action.ActionResult{}, err + } + + ag := a.pool.GetAgent(result.AgentName) + if ag == nil { + return action.ActionResult{}, fmt.Errorf("agent '%s' not found", result.AgentName) + } + + resp := ag.Ask( + agent.WithConversationHistory( + []openai.ChatCompletionMessage{ + { + Role: "user", + Content: result.Message, + }, + }, + ), + ) + if resp.Error != nil { + return action.ActionResult{}, err + } + + return action.ActionResult{Result: resp.Response}, nil +} + +func (a *CallAgentAction) Definition() action.ActionDefinition { + allAgents := a.pool.AllAgents() + + description := "Use this tool to call another agent. Available agents and their roles are:" + + for _, agent := range allAgents { + agentConfig := a.pool.GetConfig(agent) + if agentConfig == nil { + continue + } + description += fmt.Sprintf("\n- %s: %s", agent, agentConfig.Description) + } + + return action.ActionDefinition{ + Name: "call_agent", + Description: description, + Properties: map[string]jsonschema.Definition{ + "agent_name": { + Type: jsonschema.String, + Description: "The name of the agent to call.", + Enum: allAgents, + }, + "message": { + Type: jsonschema.String, + Description: "The message to send to the agent.", + }, + }, + Required: []string{"agent_name", "message"}, + } +}