feat(call_agents): allow to specify whitelist and blacklist agents (#144)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-05-10 22:27:34 +02:00
committed by GitHub
parent 324124e002
commit 289edb67a6
2 changed files with 73 additions and 6 deletions

View File

@@ -339,7 +339,7 @@ func ActionsConfigMeta() []config.FieldGroup {
{ {
Name: "call_agents", Name: "call_agents",
Label: "Call Agents", Label: "Call Agents",
Fields: []config.Field{}, Fields: actions.CallAgentConfigMeta(),
}, },
} }
} }

View File

@@ -3,23 +3,53 @@ package actions
import ( import (
"context" "context"
"fmt" "fmt"
"slices"
"strings"
"github.com/mudler/LocalAGI/core/state" "github.com/mudler/LocalAGI/core/state"
"github.com/mudler/LocalAGI/core/types" "github.com/mudler/LocalAGI/core/types"
"github.com/mudler/LocalAGI/pkg/config"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
) )
func trimList(list []string) []string {
for i, v := range list {
list[i] = strings.TrimSpace(v)
}
return list
}
func NewCallAgent(config map[string]string, agentName string, pool *state.AgentPoolInternalAPI) *CallAgentAction { func NewCallAgent(config map[string]string, agentName string, pool *state.AgentPoolInternalAPI) *CallAgentAction {
whitelist := []string{}
blacklist := []string{}
if v, ok := config["whitelist"]; ok {
if strings.Contains(v, ",") {
whitelist = trimList(strings.Split(v, ","))
} else {
whitelist = []string{v}
}
}
if v, ok := config["blacklist"]; ok {
if strings.Contains(v, ",") {
blacklist = trimList(strings.Split(v, ","))
} else {
blacklist = []string{v}
}
}
return &CallAgentAction{ return &CallAgentAction{
pool: pool, pool: pool,
myName: agentName, myName: agentName,
whitelist: whitelist,
blacklist: blacklist,
} }
} }
type CallAgentAction struct { type CallAgentAction struct {
pool *state.AgentPoolInternalAPI pool *state.AgentPoolInternalAPI
myName string myName string
whitelist []string
blacklist []string
} }
func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) { func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
@@ -83,13 +113,32 @@ func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (t
return types.ActionResult{Result: resp.Response, Metadata: metadata}, nil return types.ActionResult{Result: resp.Response, Metadata: metadata}, nil
} }
func (a *CallAgentAction) isAllowedToBeCalled(agentName string) bool {
if agentName == a.myName {
return false
}
if len(a.whitelist) > 0 && len(a.blacklist) > 0 {
return slices.Contains(a.whitelist, agentName) && !slices.Contains(a.blacklist, agentName)
}
if len(a.whitelist) > 0 {
return slices.Contains(a.whitelist, agentName)
}
if len(a.blacklist) > 0 {
return !slices.Contains(a.blacklist, agentName)
}
return true
}
func (a *CallAgentAction) Definition() types.ActionDefinition { func (a *CallAgentAction) Definition() types.ActionDefinition {
allAgents := a.pool.AllAgents() allAgents := a.pool.AllAgents()
agents := []string{} agents := []string{}
for _, ag := range allAgents { for _, ag := range allAgents {
if ag != a.myName { if a.isAllowedToBeCalled(ag) {
agents = append(agents, ag) agents = append(agents, ag)
} }
} }
@@ -125,3 +174,21 @@ func (a *CallAgentAction) Definition() types.ActionDefinition {
func (a *CallAgentAction) Plannable() bool { func (a *CallAgentAction) Plannable() bool {
return true return true
} }
func CallAgentConfigMeta() []config.Field {
return []config.Field{
{
Name: "whitelist",
Label: "Whitelist",
Type: config.FieldTypeText,
Required: false,
HelpText: "Comma-separated list of agent names to call. If not specified, all agents are allowed.",
},
{
Name: "blacklist",
Label: "Blacklist",
Type: config.FieldTypeText,
HelpText: "Comma-separated list of agent names to exclude from the call. If not specified, all agents are allowed.",
},
}
}