From 96091a1ad5331bcc78550c7213102c0964ead735 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Mon, 24 Feb 2025 23:21:51 +0100 Subject: [PATCH] Add custom actions with golang interpreter Signed-off-by: Ettore Di Giacinto --- agent/actions_custom.go | 126 +++++++++++++++++++++++++++++++++++ agent/actions_custom_test.go | 87 ++++++++++++++++++++++++ 2 files changed, 213 insertions(+) create mode 100644 agent/actions_custom.go create mode 100644 agent/actions_custom_test.go diff --git a/agent/actions_custom.go b/agent/actions_custom.go new file mode 100644 index 0000000..f2948f5 --- /dev/null +++ b/agent/actions_custom.go @@ -0,0 +1,126 @@ +package agent + +import ( + "context" + "fmt" + "strings" + + "github.com/mudler/local-agent-framework/action" + "github.com/mudler/local-agent-framework/xlog" + "github.com/sashabaranov/go-openai/jsonschema" + "github.com/traefik/yaegi/interp" + "github.com/traefik/yaegi/stdlib" +) + +func NewCustom(config map[string]string, goPkgPath string) (*CustomAction, error) { + a := &CustomAction{ + config: config, + goPkgPath: goPkgPath, + } + + if err := a.initializeInterpreter(); err != nil { + return nil, err + } + + if err := a.callInit(); err != nil { + xlog.Error("Error calling custom action init", "error", err) + } + + return a, nil +} + +type CustomAction struct { + config map[string]string + goPkgPath string + i *interp.Interpreter +} + +func (a *CustomAction) callInit() error { + if a.i == nil { + return nil + } + + v, err := a.i.Eval(fmt.Sprintf("%s.Init", a.config["name"])) + if err != nil { + return err + } + + run := v.Interface().(func() error) + + return run() +} + +func (a *CustomAction) initializeInterpreter() error { + if _, exists := a.config["code"]; exists && a.i == nil { + unsafe := strings.ToLower(a.config["unsafe"]) == "true" + i := interp.New(interp.Options{ + GoPath: a.goPkgPath, + Unrestricted: unsafe, + }) + if err := i.Use(stdlib.Symbols); err != nil { + return err + } + + if _, exists := a.config["name"]; !exists { + a.config["name"] = "custom" + } + + _, err := i.Eval(fmt.Sprintf("package %s\n%s", a.config["name"], a.config["code"])) + if err != nil { + return err + } + + a.i = i + } + + return nil +} + +func (a *CustomAction) Run(ctx context.Context, params action.ActionParams) (string, error) { + v, err := a.i.Eval(fmt.Sprintf("%s.Run", a.config["name"])) + if err != nil { + return "", err + } + + run := v.Interface().(func(map[string]interface{}) (string, error)) + + return run(params) +} + +func (a *CustomAction) Definition() action.ActionDefinition { + + v, err := a.i.Eval(fmt.Sprintf("%s.Definition", a.config["name"])) + if err != nil { + xlog.Error("Error getting custom action definition", "error", err) + return action.ActionDefinition{} + } + + properties := v.Interface().(func() map[string][]string) + + v, err = a.i.Eval(fmt.Sprintf("%s.RequiredFields", a.config["name"])) + if err != nil { + xlog.Error("Error getting custom action definition", "error", err) + return action.ActionDefinition{} + } + + requiredFields := v.Interface().(func() []string) + + prop := map[string]jsonschema.Definition{} + + for k, v := range properties() { + if len(v) != 2 { + xlog.Error("Invalid property definition", "property", k) + continue + } + prop[k] = jsonschema.Definition{ + Type: jsonschema.DataType(v[0]), + Description: v[1], + } + } + return action.ActionDefinition{ + Name: action.ActionDefinitionName(a.config["name"]), + Description: a.config["description"], + Properties: prop, + Required: requiredFields(), + } +} diff --git a/agent/actions_custom_test.go b/agent/actions_custom_test.go new file mode 100644 index 0000000..8476767 --- /dev/null +++ b/agent/actions_custom_test.go @@ -0,0 +1,87 @@ +package agent_test + +import ( + "context" + + "github.com/mudler/local-agent-framework/action" + . "github.com/mudler/local-agent-framework/agent" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +var _ = Describe("Agent custom action", func() { + Context("custom action", func() { + It("initializes correctly", func() { + + testCode := ` + +import ( + "encoding/json" +) +type Params struct { + Foo string +} + +func Run(config map[string]interface{}) (string, error) { + +p := Params{} +b, err := json.Marshal(config) + if err != nil { + return "", err + } +if err := json.Unmarshal(b, &p); err != nil { + return "", err +} + +return p.Foo, nil +} + +func Definition() map[string][]string { +return map[string][]string{ + "foo": []string{ + "string", + "The foo value", + }, + } +} + +func RequiredFields() []string { +return []string{"foo"} +} + + ` + + customAction, err := NewCustom( + map[string]string{ + "code": testCode, + "name": "test", + "description": "A test action", + }, + "", + ) + Expect(err).ToNot(HaveOccurred()) + + definition := customAction.Definition() + Expect(definition).To(Equal(action.ActionDefinition{ + Properties: map[string]jsonschema.Definition{ + "foo": { + Type: jsonschema.String, + Description: "The foo value", + }, + }, + Required: []string{"foo"}, + Name: "test", + Description: "A test action", + })) + + runResult, err := customAction.Run(context.Background(), action.ActionParams{ + "Foo": "bar", + }) + Expect(err).ToNot(HaveOccurred()) + Expect(runResult).To(Equal("bar")) + + }) + }) +})