diff --git a/README.md b/README.md index e562dc9..87d52e8 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,8 @@ Still having issues? see this Youtube video: https://youtu.be/HtVwIxW3ePg [](https://youtu.be/HtVwIxW3ePg) [](https://youtu.be/v82rswGJt_M) +[](https://youtu.be/d_we-AYksSw) + ## 📚🆕 Local Stack Family diff --git a/core/agent/agent.go b/core/agent/agent.go index 80dacb6..658fec9 100644 --- a/core/agent/agent.go +++ b/core/agent/agent.go @@ -492,6 +492,73 @@ func (a *Agent) processUserInputs(job *types.Job, role string, conv Messages) Me return conv } +func (a *Agent) filterJob(job *types.Job) (ok bool, err error) { + hasTriggers := false + triggeredBy := "" + failedBy := "" + + if job.DoneFilter { + return true, nil + } + job.DoneFilter = true + + if len(a.options.jobFilters) < 1 { + xlog.Debug("No filters") + return true, nil + } + + for _, filter := range a.options.jobFilters { + name := filter.Name() + if triggeredBy != "" && filter.IsTrigger() { + continue + } + + ok, err = filter.Apply(job) + if err != nil { + xlog.Error("Error in job filter", "filter", name, "error", err) + failedBy = name + break + } + + if filter.IsTrigger() { + hasTriggers = true + if ok { + triggeredBy = name + xlog.Info("Job triggered by filter", "filter", name) + } + } else if !ok { + failedBy = name + xlog.Info("Job failed filter", "filter", name) + break + } else { + xlog.Debug("Job passed filter", "filter", name) + } + } + + if a.Observer() != nil { + obs := a.Observer().NewObservable() + obs.Name = "filter" + obs.Icon = "shield" + obs.ParentID = job.Obs.ID + if err == nil { + obs.Completion = &types.Completion{ + FilterResult: &types.FilterResult{ + HasTriggers: hasTriggers, + TriggeredBy: triggeredBy, + FailedBy: failedBy, + }, + } + } else { + obs.Completion = &types.Completion{ + Error: err.Error(), + } + } + a.Observer().Update(*obs) + } + + return failedBy == "" && (!hasTriggers || triggeredBy != ""), nil +} + func (a *Agent) consumeJob(job *types.Job, role string, retries int) { if err := job.GetContext().Err(); err != nil { @@ -533,6 +600,14 @@ func (a *Agent) consumeJob(job *types.Job, role string, retries int) { } conv = a.processPrompts(conv) + if ok, err := a.filterJob(job); !ok || err != nil { + if err != nil { + job.Result.Finish(fmt.Errorf("Error in job filter: %w", err)) + } else { + job.Result.Finish(nil) + } + return + } conv = a.processUserInputs(job, role, conv) // RAG diff --git a/core/agent/options.go b/core/agent/options.go index 55319a3..4943dad 100644 --- a/core/agent/options.go +++ b/core/agent/options.go @@ -24,6 +24,7 @@ type options struct { randomIdentityGuidance string randomIdentity bool userActions types.Actions + jobFilters types.JobFilters enableHUD, standaloneJob, showCharacter, enableKB, enableSummaryMemory, enableLongTermMemory bool stripThinkingTags bool @@ -373,6 +374,13 @@ func WithActions(actions ...types.Action) Option { } } +func WithJobFilters(filters ...types.JobFilter) Option { + return func(o *options) error { + o.jobFilters = filters + return nil + } +} + func WithObserver(observer Observer) Option { return func(o *options) error { o.observer = observer diff --git a/core/state/config.go b/core/state/config.go index 2a7370d..d807ef7 100644 --- a/core/state/config.go +++ b/core/state/config.go @@ -31,6 +31,11 @@ func (d DynamicPromptsConfig) ToMap() map[string]string { return config } +type FiltersConfig struct { + Type string `json:"type"` + Config string `json:"config"` +} + type AgentConfig struct { Connector []ConnectorConfig `json:"connectors" form:"connectors" ` Actions []ActionsConfig `json:"actions" form:"actions"` @@ -39,6 +44,7 @@ type AgentConfig struct { MCPSTDIOServers []agent.MCPSTDIOServer `json:"mcp_stdio_servers" form:"mcp_stdio_servers"` MCPPrepareScript string `json:"mcp_prepare_script" form:"mcp_prepare_script"` MCPBoxURL string `json:"mcp_box_url" form:"mcp_box_url"` + Filters []FiltersConfig `json:"filters" form:"filters"` Description string `json:"description" form:"description"` @@ -71,6 +77,7 @@ type AgentConfig struct { } type AgentConfigMeta struct { + Filters []config.FieldGroup Fields []config.Field Connectors []config.FieldGroup Actions []config.FieldGroup @@ -82,6 +89,7 @@ func NewAgentConfigMeta( actionsConfig []config.FieldGroup, connectorsConfig []config.FieldGroup, dynamicPromptsConfig []config.FieldGroup, + filtersConfig []config.FieldGroup, ) AgentConfigMeta { return AgentConfigMeta{ Fields: []config.Field{ @@ -319,6 +327,7 @@ func NewAgentConfigMeta( DynamicPrompts: dynamicPromptsConfig, Connectors: connectorsConfig, Actions: actionsConfig, + Filters: filtersConfig, } } diff --git a/core/state/pool.go b/core/state/pool.go index 97ad031..5cbac14 100644 --- a/core/state/pool.go +++ b/core/state/pool.go @@ -38,6 +38,7 @@ type AgentPool struct { availableActions func(*AgentConfig) func(ctx context.Context, pool *AgentPool) []types.Action connectors func(*AgentConfig) []Connector dynamicPrompt func(*AgentConfig) []DynamicPrompt + filters func(*AgentConfig) types.JobFilters timeout string conversationLogs string } @@ -78,6 +79,7 @@ func NewAgentPool( availableActions func(*AgentConfig) func(ctx context.Context, pool *AgentPool) []types.Action, connectors func(*AgentConfig) []Connector, promptBlocks func(*AgentConfig) []DynamicPrompt, + filters func(*AgentConfig) types.JobFilters, timeout string, withLogs bool, ) (*AgentPool, error) { @@ -110,6 +112,7 @@ func NewAgentPool( connectors: connectors, availableActions: availableActions, dynamicPrompt: promptBlocks, + filters: filters, timeout: timeout, conversationLogs: conversationPath, }, nil @@ -135,6 +138,7 @@ func NewAgentPool( connectors: connectors, localRAGAPI: LocalRAGAPI, dynamicPrompt: promptBlocks, + filters: filters, availableActions: availableActions, timeout: timeout, conversationLogs: conversationPath, @@ -337,6 +341,8 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O if config.Model != "" { model = config.Model + } else { + config.Model = model } if config.MCPBoxURL != "" { @@ -347,12 +353,17 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O config.PeriodicRuns = "10m" } + // XXX: Why do we update the pool config from an Agent's config? if config.APIURL != "" { a.apiURL = config.APIURL + } else { + config.APIURL = a.apiURL } if config.APIKey != "" { a.apiKey = config.APIKey + } else { + config.APIKey = a.apiKey } if config.LocalRAGURL != "" { @@ -366,6 +377,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O connectors := a.connectors(config) promptBlocks := a.dynamicPrompt(config) actions := a.availableActions(config)(ctx, a) + filters := a.filters(config) stateFile, characterFile := a.stateFiles(name) actionsLog := []string{} @@ -378,6 +390,11 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O connectorLog = append(connectorLog, fmt.Sprintf("%+v", connector)) } + filtersLog := []string{} + for _, filter := range filters { + filtersLog = append(filtersLog, filter.Name()) + } + xlog.Info( "Creating agent", "name", name, @@ -385,6 +402,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O "api_url", a.apiURL, "actions", actionsLog, "connectors", connectorLog, + "filters", filtersLog, ) // dynamicPrompts := []map[string]string{} @@ -406,6 +424,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O WithMCPSTDIOServers(config.MCPSTDIOServers...), WithMCPBoxURL(a.mcpBoxURL), WithPrompts(promptBlocks...), + WithJobFilters(filters...), WithMCPPrepareScript(config.MCPPrepareScript), // WithDynamicPrompts(dynamicPrompts...), WithCharacter(Character{ diff --git a/core/types/filters.go b/core/types/filters.go new file mode 100644 index 0000000..dbcf585 --- /dev/null +++ b/core/types/filters.go @@ -0,0 +1,15 @@ +package types + +type JobFilter interface { + Name() string + Apply(job *Job) (bool, error) + IsTrigger() bool +} + +type JobFilters []JobFilter + +type FilterResult struct { + HasTriggers bool `json:"has_triggers"` + TriggeredBy string `json:"triggered_by,omitempty"` + FailedBy string `json:"failed_by,omitempty"` +} diff --git a/core/types/job.go b/core/types/job.go index c699c2e..7a48ee1 100644 --- a/core/types/job.go +++ b/core/types/job.go @@ -19,6 +19,7 @@ type Job struct { ConversationHistory []openai.ChatCompletionMessage UUID string Metadata map[string]interface{} + DoneFilter bool pastActions []*ActionRequest nextAction *Action diff --git a/core/types/observable.go b/core/types/observable.go index 48844e9..3991476 100644 --- a/core/types/observable.go +++ b/core/types/observable.go @@ -24,7 +24,8 @@ type Completion struct { ChatCompletionResponse *openai.ChatCompletionResponse `json:"chat_completion_response,omitempty"` Conversation []openai.ChatCompletionMessage `json:"conversation,omitempty"` ActionResult string `json:"action_result,omitempty"` - AgentState *AgentInternalState `json:"agent_state"` + AgentState *AgentInternalState `json:"agent_state,omitempty"` + FilterResult *FilterResult `json:"filter_result,omitempty"` } type Observable struct { diff --git a/main.go b/main.go index c20c815..fc306d6 100644 --- a/main.go +++ b/main.go @@ -70,6 +70,7 @@ func main() { }), services.Connectors, services.DynamicPrompts, + services.Filters, timeout, withLogs, ) diff --git a/services/filters.go b/services/filters.go new file mode 100644 index 0000000..a1d0c3a --- /dev/null +++ b/services/filters.go @@ -0,0 +1,44 @@ +package services + +import ( + "github.com/mudler/LocalAGI/core/state" + "github.com/mudler/LocalAGI/core/types" + "github.com/mudler/LocalAGI/pkg/config" + "github.com/mudler/LocalAGI/pkg/xlog" + "github.com/mudler/LocalAGI/services/filters" +) + +func Filters(a *state.AgentConfig) types.JobFilters { + var result []types.JobFilter + for _, f := range a.Filters { + var filter types.JobFilter + var err error + switch f.Type { + case filters.FilterRegex: + filter, err = filters.NewRegexFilter(f.Config) + if err != nil { + xlog.Error("Failed to configure regex", "err", err.Error()) + continue + } + case filters.FilterClassifier: + filter, err = filters.NewClassifierFilter(f.Config, a) + if err != nil { + xlog.Error("failed to configure classifier", "err", err.Error()) + continue + } + default: + xlog.Error("Unrecognized filter type", "type", f.Type) + continue + } + result = append(result, filter) + } + return result +} + +// FiltersConfigMeta returns all filter config metas for UI. +func FiltersConfigMeta() []config.FieldGroup { + return []config.FieldGroup{ + filters.RegexFilterConfigMeta(), + filters.ClassifierFilterConfigMeta(), + } +} diff --git a/services/filters/classifier.go b/services/filters/classifier.go new file mode 100644 index 0000000..d85aa4f --- /dev/null +++ b/services/filters/classifier.go @@ -0,0 +1,121 @@ +package filters + +import ( + "encoding/json" + "fmt" + + "github.com/mudler/LocalAGI/core/state" + "github.com/mudler/LocalAGI/core/types" + "github.com/mudler/LocalAGI/pkg/config" + "github.com/mudler/LocalAGI/pkg/llm" + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +const FilterClassifier = "classifier" + +type ClassifierFilter struct { + name string + client *openai.Client + model string + description string + allowOnMatch bool + isTrigger bool +} + +type ClassifierFilterConfig struct { + Name string `json:"name"` + Model string `json:"model,omitempty"` + APIURL string `json:"api_url,omitempty"` + Description string `json:"description"` + AllowOnMatch bool `json:"allow_on_match"` + IsTrigger bool `json:"is_trigger"` +} + +func NewClassifierFilter(configJSON string, a *state.AgentConfig) (*ClassifierFilter, error) { + var cfg ClassifierFilterConfig + if err := json.Unmarshal([]byte(configJSON), &cfg); err != nil { + return nil, err + } + var model string + if cfg.Model != "" { + model = cfg.Model + } else { + model = a.Model + } + if cfg.Name == "" { + return nil, fmt.Errorf("Classifier with no name") + } + if cfg.Description == "" { + return nil, fmt.Errorf("%s classifier has no description", cfg.Name) + } + apiUrl := a.APIURL + if cfg.APIURL != "" { + apiUrl = cfg.APIURL + } + client := llm.NewClient(a.APIKey, apiUrl, "1m") + + return &ClassifierFilter{ + name: cfg.Name, + model: model, + description: cfg.Description, + client: client, + allowOnMatch: cfg.AllowOnMatch, + isTrigger: cfg.IsTrigger, + }, nil +} + +const fmtT = ` + Does the below message fit the description "%s" + + %s + ` + +func (f *ClassifierFilter) Name() string { return f.name } +func (f *ClassifierFilter) Apply(job *types.Job) (bool, error) { + input := extractInputFromJob(job) + guidance := fmt.Sprintf(fmtT, f.description, input) + var result struct { + Asserted bool `json:"answer"` + } + err := llm.GenerateTypedJSON(job.GetContext(), f.client, guidance, f.model, jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "answer": { + Type: jsonschema.Boolean, + Description: "The answer to the first question", + }, + }, + Required: []string{"answer"}, + }, &result) + if err != nil { + return false, err + } + + if result.Asserted { + return f.allowOnMatch, nil + } + return !f.allowOnMatch, nil +} + +func (f *ClassifierFilter) IsTrigger() bool { + return f.isTrigger +} + +func ClassifierFilterConfigMeta() config.FieldGroup { + return config.FieldGroup{ + Name: FilterClassifier, + Label: "Classifier Filter/Trigger", + Fields: []config.Field{ + {Name: "name", Label: "Name", Type: "text", Required: true}, + {Name: "model", Label: "Model", Type: "text", Required: false, + HelpText: "The LLM to use, usually a smaller one. Leave blank to use the same as the agent's"}, + {Name: "api_url", Label: "API URL", Type: "url", Required: false, + HelpText: "The URL of the LLM service if different from the agent's"}, + {Name: "description", Label: "Description", Type: "text", Required: true, + HelpText: "Describe the type of content to match against e.g. 'technical support request'"}, + {Name: "allow_on_match", Label: "Allow on Match", Type: "checkbox", Required: true}, + {Name: "is_trigger", Label: "Is Trigger", Type: "checkbox", Required: true}, + }, + } +} diff --git a/services/filters/regex.go b/services/filters/regex.go new file mode 100644 index 0000000..db12e0c --- /dev/null +++ b/services/filters/regex.go @@ -0,0 +1,86 @@ +package filters + +import ( + "encoding/json" + "regexp" + + "github.com/mudler/LocalAGI/core/types" + "github.com/mudler/LocalAGI/pkg/config" +) + +const FilterRegex = "regex" + +type RegexFilter struct { + name string + pattern *regexp.Regexp + allowOnMatch bool + isTrigger bool +} + +type RegexFilterConfig struct { + Name string `json:"name"` + Pattern string `json:"pattern"` + AllowOnMatch bool `json:"allow_on_match"` + IsTrigger bool `json:"is_trigger"` +} + +func NewRegexFilter(configJSON string) (*RegexFilter, error) { + var cfg RegexFilterConfig + if err := json.Unmarshal([]byte(configJSON), &cfg); err != nil { + return nil, err + } + re, err := regexp.Compile(cfg.Pattern) + if err != nil { + return nil, err + } + return &RegexFilter{ + name: cfg.Name, + pattern: re, + allowOnMatch: cfg.AllowOnMatch, + isTrigger: cfg.IsTrigger, + }, nil +} + +func (f *RegexFilter) Name() string { return f.name } +func (f *RegexFilter) Apply(job *types.Job) (bool, error) { + input := extractInputFromJob(job) + if f.pattern.MatchString(input) { + return f.allowOnMatch, nil + } + return !f.allowOnMatch, nil +} + +func (f *RegexFilter) IsTrigger() bool { + return f.isTrigger +} + +func RegexFilterConfigMeta() config.FieldGroup { + return config.FieldGroup{ + Name: FilterRegex, + Label: "Regex Filter/Trigger", + Fields: []config.Field{ + {Name: "name", Label: "Name", Type: "text", Required: true}, + {Name: "pattern", Label: "Pattern", Type: "text", Required: true}, + {Name: "allow_on_match", Label: "Allow on Match", Type: "checkbox", Required: true}, + {Name: "is_trigger", Label: "Is Trigger", Type: "checkbox", Required: true}, + }, + } +} + +// extractInputFromJob attempts to extract a string input for filtering. +func extractInputFromJob(job *types.Job) string { + if job.Metadata != nil { + if v, ok := job.Metadata["input"]; ok { + if s, ok := v.(string); ok { + return s + } + } + } + // fallback: try to use conversation history if available + if len(job.ConversationHistory) > 0 { + // Use the last message content + last := job.ConversationHistory[len(job.ConversationHistory)-1] + return last.Content + } + return "" +} diff --git a/webui/app.go b/webui/app.go index 19632b4..acf5291 100644 --- a/webui/app.go +++ b/webui/app.go @@ -644,6 +644,7 @@ func (a *App) GetAgentConfigMeta() func(c *fiber.Ctx) error { services.ActionsConfigMeta(), services.ConnectorsConfigMeta(), services.DynamicPromptsConfigMeta(), + services.FiltersConfigMeta(), ) return c.JSON(configMeta) } diff --git a/webui/react-ui/src/components/AgentForm.jsx b/webui/react-ui/src/components/AgentForm.jsx index 94e5e64..df0d6a4 100644 --- a/webui/react-ui/src/components/AgentForm.jsx +++ b/webui/react-ui/src/components/AgentForm.jsx @@ -11,6 +11,7 @@ import ModelSettingsSection from './agent-form-sections/ModelSettingsSection'; import PromptsGoalsSection from './agent-form-sections/PromptsGoalsSection'; import AdvancedSettingsSection from './agent-form-sections/AdvancedSettingsSection'; import ExportSection from './agent-form-sections/ExportSection'; +import FiltersSection from './agent-form-sections/FiltersSection'; const AgentForm = ({ isEdit = false, @@ -189,6 +190,13 @@ const AgentForm = ({ Connectors +
+ Jobs received by the agent must pass all filters and at least one trigger (if any are specified) +
+ +