diff --git a/core/action/action_suite_test.go b/core/action/action_suite_test.go new file mode 100644 index 0000000..ea3ec2a --- /dev/null +++ b/core/action/action_suite_test.go @@ -0,0 +1,13 @@ +package action_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestAction(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Agent Action test suite") +} diff --git a/core/action/custom.go b/core/action/custom.go index 3bd34b6..7ca84c4 100644 --- a/core/action/custom.go +++ b/core/action/custom.go @@ -75,15 +75,16 @@ func (a *CustomAction) initializeInterpreter() error { return nil } -func (a *CustomAction) Run(ctx context.Context, params ActionParams) (string, error) { +func (a *CustomAction) Run(ctx context.Context, params ActionParams) (ActionResult, error) { v, err := a.i.Eval(fmt.Sprintf("%s.Run", a.config["name"])) if err != nil { - return "", err + return ActionResult{}, err } - run := v.Interface().(func(map[string]interface{}) (string, error)) + run := v.Interface().(func(map[string]interface{}) (string, map[string]interface{}, error)) - return run(params) + res, meta, err := run(params) + return ActionResult{Result: res, Metadata: meta}, err } func (a *CustomAction) Definition() ActionDefinition { diff --git a/core/action/custom_test.go b/core/action/custom_test.go index 743e965..4d780a0 100644 --- a/core/action/custom_test.go +++ b/core/action/custom_test.go @@ -23,18 +23,18 @@ type Params struct { Foo string } -func Run(config map[string]interface{}) (string, error) { +func Run(config map[string]interface{}) (string, map[string]interface{}, error) { p := Params{} b, err := json.Marshal(config) if err != nil { - return "", err + return "",map[string]interface{}{}, err } if err := json.Unmarshal(b, &p); err != nil { - return "", err + return "",map[string]interface{}{}, err } -return p.Foo, nil +return p.Foo,map[string]interface{}{}, nil } func Definition() map[string][]string { @@ -79,7 +79,7 @@ return []string{"foo"} "Foo": "bar", }) Expect(err).ToNot(HaveOccurred()) - Expect(runResult).To(Equal("bar")) + Expect(runResult.Result).To(Equal("bar")) }) }) diff --git a/core/action/definition.go b/core/action/definition.go index 5179f41..58ac980 100644 --- a/core/action/definition.go +++ b/core/action/definition.go @@ -28,6 +28,11 @@ func NewContext(ctx context.Context, cancel context.CancelFunc) *ActionContext { type ActionParams map[string]interface{} +type ActionResult struct { + Result string + Metadata map[string]interface{} +} + func (ap ActionParams) Read(s string) error { err := json.Unmarshal([]byte(s), &ap) return err diff --git a/core/action/intention.go b/core/action/intention.go index 968e8f0..95c540a 100644 --- a/core/action/intention.go +++ b/core/action/intention.go @@ -21,8 +21,8 @@ type IntentResponse struct { Reasoning string `json:"reasoning"` } -func (a *IntentAction) Run(context.Context, ActionParams) (string, error) { - return "no-op", nil +func (a *IntentAction) Run(context.Context, ActionParams) (ActionResult, error) { + return ActionResult{}, nil } func (a *IntentAction) Definition() ActionDefinition { diff --git a/core/action/newconversation.go b/core/action/newconversation.go index 5d0f94b..067f54f 100644 --- a/core/action/newconversation.go +++ b/core/action/newconversation.go @@ -18,8 +18,8 @@ type ConversationActionResponse struct { Message string `json:"message"` } -func (a *ConversationAction) Run(context.Context, ActionParams) (string, error) { - return "no-op", nil +func (a *ConversationAction) Run(context.Context, ActionParams) (ActionResult, error) { + return ActionResult{}, nil } func (a *ConversationAction) Definition() ActionDefinition { diff --git a/core/action/noreply.go b/core/action/noreply.go index e6cc755..af00465 100644 --- a/core/action/noreply.go +++ b/core/action/noreply.go @@ -12,8 +12,8 @@ func NewStop() *StopAction { type StopAction struct{} -func (a *StopAction) Run(context.Context, ActionParams) (string, error) { - return "no-op", nil +func (a *StopAction) Run(context.Context, ActionParams) (ActionResult, error) { + return ActionResult{}, nil } func (a *StopAction) Definition() ActionDefinition { diff --git a/core/action/reasoning.go b/core/action/reasoning.go index 122e222..acd3262 100644 --- a/core/action/reasoning.go +++ b/core/action/reasoning.go @@ -19,8 +19,8 @@ type ReasoningResponse struct { Reasoning string `json:"reasoning"` } -func (a *ReasoningAction) Run(context.Context, ActionParams) (string, error) { - return "no-op", nil +func (a *ReasoningAction) Run(context.Context, ActionParams) (ActionResult, error) { + return ActionResult{}, nil } func (a *ReasoningAction) Definition() ActionDefinition { diff --git a/core/action/state.go b/core/action/state.go index f4d8901..3e79a3b 100644 --- a/core/action/state.go +++ b/core/action/state.go @@ -33,8 +33,8 @@ type StateResult struct { Goal string `json:"goal"` } -func (a *StateAction) Run(context.Context, ActionParams) (string, error) { - return "internal state has been updated", nil +func (a *StateAction) Run(context.Context, ActionParams) (ActionResult, error) { + return ActionResult{Result: "internal state has been updated"}, nil } func (a *StateAction) Definition() ActionDefinition { diff --git a/core/agent/actions.go b/core/agent/actions.go index 7576c89..aec695f 100644 --- a/core/agent/actions.go +++ b/core/agent/actions.go @@ -23,7 +23,7 @@ type ActionCurrentState struct { // Actions is something the agent can do type Action interface { - Run(ctx context.Context, action action.ActionParams) (string, error) + Run(ctx context.Context, action action.ActionParams) (action.ActionResult, error) Definition() action.ActionDefinition } diff --git a/core/agent/agent.go b/core/agent/agent.go index 8c46c51..d27f48f 100644 --- a/core/agent/agent.go +++ b/core/agent/agent.go @@ -242,9 +242,12 @@ func (a *Agent) Memory() RAGDB { func (a *Agent) runAction(chosenAction Action, params action.ActionParams) (result string, err error) { for _, action := range a.systemInternalActions() { if action.Definition().Name == chosenAction.Definition().Name { - if result, err = action.Run(a.actionContext, params); err != nil { + res, err := action.Run(a.actionContext, params) + if err != nil { return "", fmt.Errorf("error running action: %w", err) } + + result = res.Result } } diff --git a/core/agent/agent_test.go b/core/agent/agent_test.go index 3b468bd..d90f1c4 100644 --- a/core/agent/agent_test.go +++ b/core/agent/agent_test.go @@ -36,7 +36,7 @@ type TestAction struct { responseN int } -func (a *TestAction) Run(context.Context, action.ActionParams) (string, error) { +func (a *TestAction) Run(context.Context, action.ActionParams) (action.ActionResult, error) { res := a.response[a.responseN] a.responseN++ @@ -44,7 +44,7 @@ func (a *TestAction) Run(context.Context, action.ActionParams) (string, error) { a.responseN = 0 } - return res, nil + return action.ActionResult{Result: res}, nil } func (a *TestAction) Definition() action.ActionDefinition { diff --git a/services/actions/browse.go b/services/actions/browse.go index 1057982..c99d738 100644 --- a/services/actions/browse.go +++ b/services/actions/browse.go @@ -18,7 +18,7 @@ func NewBrowse(config map[string]string) *BrowseAction { type BrowseAction struct{} -func (a *BrowseAction) Run(ctx context.Context, params action.ActionParams) (string, error) { +func (a *BrowseAction) Run(ctx context.Context, params action.ActionParams) (action.ActionResult, error) { result := struct { URL string `json:"url"` }{} @@ -26,31 +26,31 @@ func (a *BrowseAction) Run(ctx context.Context, params action.ActionParams) (str if err != nil { fmt.Printf("error: %v", err) - return "", err + return action.ActionResult{}, err } // download page with http.Client client := &http.Client{} req, err := http.NewRequest("GET", result.URL, nil) if err != nil { - return "", err + return action.ActionResult{}, err } resp, err := client.Do(req) if err != nil { - return "", err + return action.ActionResult{}, err } defer resp.Body.Close() pagebyte, err := io.ReadAll(resp.Body) if err != nil { - return "", err + return action.ActionResult{}, err } rendered, err := html2text.FromString(string(pagebyte), html2text.Options{PrettyTables: true}) if err != nil { - return "", err + return action.ActionResult{}, err } - return fmt.Sprintf("The webpage '%s' content is:\n%s", result.URL, rendered), nil + return action.ActionResult{Result: fmt.Sprintf("The webpage '%s' content is:\n%s", result.URL, rendered)}, nil } func (a *BrowseAction) Definition() action.ActionDefinition { diff --git a/services/actions/genimage.go b/services/actions/genimage.go index 76bdce4..f17ea87 100644 --- a/services/actions/genimage.go +++ b/services/actions/genimage.go @@ -24,18 +24,18 @@ type GenImageAction struct { imageModel string } -func (a *GenImageAction) Run(ctx context.Context, params action.ActionParams) (string, error) { +func (a *GenImageAction) Run(ctx context.Context, params action.ActionParams) (action.ActionResult, error) { result := struct { Prompt string `json:"prompt"` Size string `json:"size"` }{} err := params.Unmarshal(&result) if err != nil { - return "", err + return action.ActionResult{}, err } if result.Prompt == "" { - return "", fmt.Errorf("prompt is required") + return action.ActionResult{}, fmt.Errorf("prompt is required") } req := openai.ImageRequest{ @@ -56,14 +56,17 @@ func (a *GenImageAction) Run(ctx context.Context, params action.ActionParams) (s resp, err := a.client.CreateImage(ctx, req) if err != nil { - return "Failed to generate image " + err.Error(), err + return action.ActionResult{Result: "Failed to generate image " + err.Error()}, err } if len(resp.Data) == 0 { - return "Failed to generate image", nil + return action.ActionResult{Result: "Failed to generate image"}, nil } - return fmt.Sprintf("The image was generated and available at: %s", resp.Data[0].URL), nil + return action.ActionResult{ + Result: fmt.Sprintf("The image was generated and available at: %s", resp.Data[0].URL), Metadata: map[string]interface{}{ + "url": resp.Data[0].URL, + }}, nil } func (a *GenImageAction) Definition() action.ActionDefinition { diff --git a/services/actions/githubissuecloser.go b/services/actions/githubissuecloser.go index 6468617..206c107 100644 --- a/services/actions/githubissuecloser.go +++ b/services/actions/githubissuecloser.go @@ -24,7 +24,7 @@ func NewGithubIssueCloser(ctx context.Context, config map[string]string) *Github } } -func (g *GithubIssuesCloser) Run(ctx context.Context, params action.ActionParams) (string, error) { +func (g *GithubIssuesCloser) Run(ctx context.Context, params action.ActionParams) (action.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` @@ -34,7 +34,7 @@ func (g *GithubIssuesCloser) Run(ctx context.Context, params action.ActionParams if err != nil { fmt.Printf("error: %v", err) - return "", err + return action.ActionResult{}, err } // _, _, err = g.client.Issues.CreateComment( @@ -57,14 +57,14 @@ func (g *GithubIssuesCloser) Run(ctx context.Context, params action.ActionParams if err != nil { fmt.Printf("error: %v", err) - return "", err + return action.ActionResult{}, err } resultString := fmt.Sprintf("Closed issue %d in repository %s/%s", result.IssueNumber, result.Owner, result.Repository) if err != nil { resultString = fmt.Sprintf("Error closing issue %d in repository %s/%s: %v", result.IssueNumber, result.Owner, result.Repository, err) } - return resultString, err + return action.ActionResult{Result: resultString}, err } func (g *GithubIssuesCloser) Definition() action.ActionDefinition { diff --git a/services/actions/githubissuelabeler.go b/services/actions/githubissuelabeler.go index f67dd35..db93ebd 100644 --- a/services/actions/githubissuelabeler.go +++ b/services/actions/githubissuelabeler.go @@ -36,7 +36,7 @@ func NewGithubIssueLabeler(ctx context.Context, config map[string]string) *Githu } } -func (g *GithubIssuesLabeler) Run(ctx context.Context, params action.ActionParams) (string, error) { +func (g *GithubIssuesLabeler) Run(ctx context.Context, params action.ActionParams) (action.ActionResult, error) { result := struct { Repository string `json:"repository"` Owner string `json:"owner"` @@ -45,9 +45,7 @@ func (g *GithubIssuesLabeler) Run(ctx context.Context, params action.ActionParam }{} err := params.Unmarshal(&result) if err != nil { - fmt.Printf("error: %v", err) - - return "", err + return action.ActionResult{}, err } labels, _, err := g.client.Issues.AddLabelsToIssue(g.context, result.Owner, result.Repository, result.IssueNumber, []string{result.Label}) @@ -61,7 +59,7 @@ func (g *GithubIssuesLabeler) Run(ctx context.Context, params action.ActionParam if err != nil { resultString = fmt.Sprintf("Error adding label '%s' to issue %d in repository %s/%s: %v", result.Label, result.IssueNumber, result.Owner, result.Repository, err) } - return resultString, err + return action.ActionResult{Result: resultString}, err } func (g *GithubIssuesLabeler) Definition() action.ActionDefinition { diff --git a/services/actions/githubissueopener.go b/services/actions/githubissueopener.go index 4e4548c..7a70273 100644 --- a/services/actions/githubissueopener.go +++ b/services/actions/githubissueopener.go @@ -25,7 +25,7 @@ func NewGithubIssueOpener(ctx context.Context, config map[string]string) *Github } } -func (g *GithubIssuesOpener) Run(ctx context.Context, params action.ActionParams) (string, error) { +func (g *GithubIssuesOpener) Run(ctx context.Context, params action.ActionParams) (action.ActionResult, error) { result := struct { Title string `json:"title"` Body string `json:"text"` @@ -36,7 +36,7 @@ func (g *GithubIssuesOpener) Run(ctx context.Context, params action.ActionParams if err != nil { fmt.Printf("error: %v", err) - return "", err + return action.ActionResult{}, err } issue := &github.IssueRequest{ @@ -52,7 +52,7 @@ func (g *GithubIssuesOpener) Run(ctx context.Context, params action.ActionParams resultString = fmt.Sprintf("Created issue %d in repository %s/%s", createdIssue.GetNumber(), result.Owner, result.Repository) } - return resultString, err + return action.ActionResult{Result: resultString}, err } func (g *GithubIssuesOpener) Definition() action.ActionDefinition { diff --git a/services/actions/githubissuesearch.go b/services/actions/githubissuesearch.go index 0e3beea..dbef041 100644 --- a/services/actions/githubissuesearch.go +++ b/services/actions/githubissuesearch.go @@ -26,7 +26,7 @@ func NewGithubIssueSearch(ctx context.Context, config map[string]string) *Github } } -func (g *GithubIssueSearch) Run(ctx context.Context, params action.ActionParams) (string, error) { +func (g *GithubIssueSearch) Run(ctx context.Context, params action.ActionParams) (action.ActionResult, error) { result := struct { Query string `json:"query"` Repository string `json:"repository"` @@ -36,7 +36,7 @@ func (g *GithubIssueSearch) Run(ctx context.Context, params action.ActionParams) if err != nil { fmt.Printf("error: %v", err) - return "", err + return action.ActionResult{}, err } query := fmt.Sprintf("%s in:%s user:%s", result.Query, result.Repository, result.Owner) @@ -48,7 +48,7 @@ func (g *GithubIssueSearch) Run(ctx context.Context, params action.ActionParams) }) if err != nil { resultString = fmt.Sprintf("Error listing issues: %v", err) - return resultString, err + return action.ActionResult{Result: resultString}, err } for _, i := range issues.Issues { xlog.Info("Issue found", "title", i.GetTitle()) @@ -57,7 +57,7 @@ func (g *GithubIssueSearch) Run(ctx context.Context, params action.ActionParams) // resultString += fmt.Sprintf("Body: %s\n", i.GetBody()) } - return resultString, err + return action.ActionResult{Result: resultString}, err } func (g *GithubIssueSearch) Definition() action.ActionDefinition { diff --git a/services/actions/scrape.go b/services/actions/scrape.go index a2003f7..8f69020 100644 --- a/services/actions/scrape.go +++ b/services/actions/scrape.go @@ -16,7 +16,7 @@ func NewScraper(config map[string]string) *ScraperAction { type ScraperAction struct{} -func (a *ScraperAction) Run(ctx context.Context, params action.ActionParams) (string, error) { +func (a *ScraperAction) Run(ctx context.Context, params action.ActionParams) (action.ActionResult, error) { result := struct { URL string `json:"url"` }{} @@ -24,15 +24,21 @@ func (a *ScraperAction) Run(ctx context.Context, params action.ActionParams) (st if err != nil { fmt.Printf("error: %v", err) - return "", err + return action.ActionResult{}, err } scraper, err := scraper.New() if err != nil { fmt.Printf("error: %v", err) - return "", err + return action.ActionResult{}, err } - return scraper.Call(ctx, result.URL) + res, err := scraper.Call(ctx, result.URL) + if err != nil { + fmt.Printf("error: %v", err) + + return action.ActionResult{}, err + } + return action.ActionResult{Result: res}, nil } func (a *ScraperAction) Definition() action.ActionDefinition { diff --git a/services/actions/search.go b/services/actions/search.go index db54177..15d136b 100644 --- a/services/actions/search.go +++ b/services/actions/search.go @@ -28,7 +28,7 @@ func NewSearch(config map[string]string) *SearchAction { type SearchAction struct{ results int } -func (a *SearchAction) Run(ctx context.Context, params action.ActionParams) (string, error) { +func (a *SearchAction) Run(ctx context.Context, params action.ActionParams) (action.ActionResult, error) { result := struct { Query string `json:"query"` }{} @@ -36,15 +36,21 @@ func (a *SearchAction) Run(ctx context.Context, params action.ActionParams) (str if err != nil { fmt.Printf("error: %v", err) - return "", err + return action.ActionResult{}, err } ddg, err := duckduckgo.New(a.results, "LocalAgent") if err != nil { fmt.Printf("error: %v", err) - return "", err + return action.ActionResult{}, err } - return ddg.Call(ctx, result.Query) + res, err := ddg.Call(ctx, result.Query) + if err != nil { + fmt.Printf("error: %v", err) + + return action.ActionResult{}, err + } + return action.ActionResult{Result: res}, nil } func (a *SearchAction) Definition() action.ActionDefinition { diff --git a/services/actions/sendmail.go b/services/actions/sendmail.go index b67818d..1b23742 100644 --- a/services/actions/sendmail.go +++ b/services/actions/sendmail.go @@ -27,7 +27,7 @@ type SendMailAction struct { smtpPort string } -func (a *SendMailAction) Run(ctx context.Context, params action.ActionParams) (string, error) { +func (a *SendMailAction) Run(ctx context.Context, params action.ActionParams) (action.ActionResult, error) { result := struct { Message string `json:"message"` To string `json:"to"` @@ -37,7 +37,7 @@ func (a *SendMailAction) Run(ctx context.Context, params action.ActionParams) (s if err != nil { fmt.Printf("error: %v", err) - return "", err + return action.ActionResult{}, err } // Authentication. @@ -50,9 +50,9 @@ func (a *SendMailAction) Run(ctx context.Context, params action.ActionParams) (s result.To, }, []byte(result.Message)) if err != nil { - return "", err + return action.ActionResult{}, err } - return fmt.Sprintf("Email sent to %s", result.To), nil + return action.ActionResult{Result: fmt.Sprintf("Email sent to %s", result.To)}, nil } func (a *SendMailAction) Definition() action.ActionDefinition { diff --git a/services/actions/wikipedia.go b/services/actions/wikipedia.go index d6051b2..936b5cd 100644 --- a/services/actions/wikipedia.go +++ b/services/actions/wikipedia.go @@ -15,7 +15,7 @@ func NewWikipedia(config map[string]string) *WikipediaAction { type WikipediaAction struct{} -func (a *WikipediaAction) Run(ctx context.Context, params action.ActionParams) (string, error) { +func (a *WikipediaAction) Run(ctx context.Context, params action.ActionParams) (action.ActionResult, error) { result := struct { Query string `json:"query"` }{} @@ -23,10 +23,16 @@ func (a *WikipediaAction) Run(ctx context.Context, params action.ActionParams) ( if err != nil { fmt.Printf("error: %v", err) - return "", err + return action.ActionResult{}, err } wiki := wikipedia.New("LocalAgent") - return wiki.Call(ctx, result.Query) + res, err := wiki.Call(ctx, result.Query) + if err != nil { + fmt.Printf("error: %v", err) + + return action.ActionResult{}, err + } + return action.ActionResult{Result: res}, nil } func (a *WikipediaAction) Definition() action.ActionDefinition {