feat(call_agents): allow to specify whitelist and blacklist agents (#144)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
324124e002
commit
289edb67a6
@@ -339,7 +339,7 @@ func ActionsConfigMeta() []config.FieldGroup {
|
|||||||
{
|
{
|
||||||
Name: "call_agents",
|
Name: "call_agents",
|
||||||
Label: "Call Agents",
|
Label: "Call Agents",
|
||||||
Fields: []config.Field{},
|
Fields: actions.CallAgentConfigMeta(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,23 +3,53 @@ package actions
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/mudler/LocalAGI/core/state"
|
"github.com/mudler/LocalAGI/core/state"
|
||||||
"github.com/mudler/LocalAGI/core/types"
|
"github.com/mudler/LocalAGI/core/types"
|
||||||
|
"github.com/mudler/LocalAGI/pkg/config"
|
||||||
"github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func trimList(list []string) []string {
|
||||||
|
for i, v := range list {
|
||||||
|
list[i] = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
func NewCallAgent(config map[string]string, agentName string, pool *state.AgentPoolInternalAPI) *CallAgentAction {
|
func NewCallAgent(config map[string]string, agentName string, pool *state.AgentPoolInternalAPI) *CallAgentAction {
|
||||||
|
whitelist := []string{}
|
||||||
|
blacklist := []string{}
|
||||||
|
if v, ok := config["whitelist"]; ok {
|
||||||
|
if strings.Contains(v, ",") {
|
||||||
|
whitelist = trimList(strings.Split(v, ","))
|
||||||
|
} else {
|
||||||
|
whitelist = []string{v}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := config["blacklist"]; ok {
|
||||||
|
if strings.Contains(v, ",") {
|
||||||
|
blacklist = trimList(strings.Split(v, ","))
|
||||||
|
} else {
|
||||||
|
blacklist = []string{v}
|
||||||
|
}
|
||||||
|
}
|
||||||
return &CallAgentAction{
|
return &CallAgentAction{
|
||||||
pool: pool,
|
pool: pool,
|
||||||
myName: agentName,
|
myName: agentName,
|
||||||
|
whitelist: whitelist,
|
||||||
|
blacklist: blacklist,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type CallAgentAction struct {
|
type CallAgentAction struct {
|
||||||
pool *state.AgentPoolInternalAPI
|
pool *state.AgentPoolInternalAPI
|
||||||
myName string
|
myName string
|
||||||
|
whitelist []string
|
||||||
|
blacklist []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (types.ActionResult, error) {
|
||||||
@@ -83,13 +113,32 @@ func (a *CallAgentAction) Run(ctx context.Context, params types.ActionParams) (t
|
|||||||
return types.ActionResult{Result: resp.Response, Metadata: metadata}, nil
|
return types.ActionResult{Result: resp.Response, Metadata: metadata}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *CallAgentAction) isAllowedToBeCalled(agentName string) bool {
|
||||||
|
if agentName == a.myName {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(a.whitelist) > 0 && len(a.blacklist) > 0 {
|
||||||
|
return slices.Contains(a.whitelist, agentName) && !slices.Contains(a.blacklist, agentName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(a.whitelist) > 0 {
|
||||||
|
return slices.Contains(a.whitelist, agentName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(a.blacklist) > 0 {
|
||||||
|
return !slices.Contains(a.blacklist, agentName)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (a *CallAgentAction) Definition() types.ActionDefinition {
|
func (a *CallAgentAction) Definition() types.ActionDefinition {
|
||||||
allAgents := a.pool.AllAgents()
|
allAgents := a.pool.AllAgents()
|
||||||
|
|
||||||
agents := []string{}
|
agents := []string{}
|
||||||
|
|
||||||
for _, ag := range allAgents {
|
for _, ag := range allAgents {
|
||||||
if ag != a.myName {
|
if a.isAllowedToBeCalled(ag) {
|
||||||
agents = append(agents, ag)
|
agents = append(agents, ag)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -125,3 +174,21 @@ func (a *CallAgentAction) Definition() types.ActionDefinition {
|
|||||||
func (a *CallAgentAction) Plannable() bool {
|
func (a *CallAgentAction) Plannable() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CallAgentConfigMeta() []config.Field {
|
||||||
|
return []config.Field{
|
||||||
|
{
|
||||||
|
Name: "whitelist",
|
||||||
|
Label: "Whitelist",
|
||||||
|
Type: config.FieldTypeText,
|
||||||
|
Required: false,
|
||||||
|
HelpText: "Comma-separated list of agent names to call. If not specified, all agents are allowed.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "blacklist",
|
||||||
|
Label: "Blacklist",
|
||||||
|
Type: config.FieldTypeText,
|
||||||
|
HelpText: "Comma-separated list of agent names to exclude from the call. If not specified, all agents are allowed.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user