diff --git a/services/actions.go b/services/actions.go index c4ad928..5ec2098 100644 --- a/services/actions.go +++ b/services/actions.go @@ -339,7 +339,7 @@ func ActionsConfigMeta() []config.FieldGroup { { Name: "call_agents", Label: "Call Agents", - Fields: []config.Field{}, + Fields: actions.CallAgentConfigMeta(), }, } } diff --git a/services/actions/callagents.go b/services/actions/callagents.go index a37a019..d4b65ac 100644 --- a/services/actions/callagents.go +++ b/services/actions/callagents.go @@ -3,23 +3,53 @@ package actions import ( "context" "fmt" + "slices" + "strings" "github.com/mudler/LocalAGI/core/state" "github.com/mudler/LocalAGI/core/types" + "github.com/mudler/LocalAGI/pkg/config" "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/jsonschema" ) +func trimList(list []string) []string { + for i, v := range list { + list[i] = strings.TrimSpace(v) + } + return list +} + func NewCallAgent(config map[string]string, agentName string, pool *state.AgentPoolInternalAPI) *CallAgentAction { + whitelist := []string{} + blacklist := []string{} + if v, ok := config["whitelist"]; ok { + if strings.Contains(v, ",") { + whitelist = trimList(strings.Split(v, ",")) + } else { + whitelist = []string{v} + } + } + if v, ok := config["blacklist"]; ok { + if strings.Contains(v, ",") { + blacklist = trimList(strings.Split(v, ",")) + } else { + blacklist = []string{v} + } + } return &CallAgentAction{ - pool: pool, - myName: agentName, + pool: pool, + myName: agentName, + whitelist: whitelist, + blacklist: blacklist, } } type CallAgentAction struct { - pool *state.AgentPoolInternalAPI - myName string + pool *state.AgentPoolInternalAPI + myName string + whitelist []string + blacklist []string } func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { @@ -83,13 +113,32 @@ func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (t return types.ActionResult{Result: resp.Response, Metadata: metadata}, nil } +func (a *CallAgentAction) isAllowedToBeCalled(agentName string) bool { + if agentName == a.myName { + return false + } + + if len(a.whitelist) > 0 && len(a.blacklist) > 0 { + return slices.Contains(a.whitelist, agentName) && !slices.Contains(a.blacklist, agentName) + } + + if len(a.whitelist) > 0 { + return slices.Contains(a.whitelist, agentName) + } + + if len(a.blacklist) > 0 { + return !slices.Contains(a.blacklist, agentName) + } + return true +} + func (a *CallAgentAction) Definition() types.ActionDefinition { allAgents := a.pool.AllAgents() agents := []string{} for _, ag := range allAgents { - if ag != a.myName { + if a.isAllowedToBeCalled(ag) { agents = append(agents, ag) } } @@ -125,3 +174,21 @@ func (a *CallAgentAction) Definition() types.ActionDefinition { func (a *CallAgentAction) Plannable() bool { return true } + +func CallAgentConfigMeta() []config.Field { + return []config.Field{ + { + Name: "whitelist", + Label: "Whitelist", + Type: config.FieldTypeText, + Required: false, + HelpText: "Comma-separated list of agent names to call. If not specified, all agents are allowed.", + }, + { + Name: "blacklist", + Label: "Blacklist", + Type: config.FieldTypeText, + HelpText: "Comma-separated list of agent names to exclude from the call. If not specified, all agents are allowed.", + }, + } +}