Refactorings

This commit is contained in:
Ettore Di Giacinto
2025-03-02 22:44:54 +01:00
parent f6e16be170
commit 5e52383a99
7 changed files with 110 additions and 107 deletions

View File

@@ -1,101 +1,6 @@
package agent package agent
import (
"fmt"
"strings"
"github.com/mudler/LocalAgent/pkg/xlog"
"github.com/traefik/yaegi/interp"
"github.com/traefik/yaegi/stdlib"
)
type PromptBlock interface { type PromptBlock interface {
Render(a *Agent) (string, error) Render(a *Agent) (string, error)
Role() string Role() string
} }
type DynamicPrompt struct {
config map[string]string
goPkgPath string
i *interp.Interpreter
}
func NewDynamicPrompt(config map[string]string, goPkgPath string) (*DynamicPrompt, error) {
a := &DynamicPrompt{
config: config,
goPkgPath: goPkgPath,
}
if err := a.initializeInterpreter(); err != nil {
return nil, err
}
if err := a.callInit(); err != nil {
xlog.Error("Error calling custom action init", "error", err)
}
return a, nil
}
func (a *DynamicPrompt) callInit() error {
if a.i == nil {
return nil
}
v, err := a.i.Eval(fmt.Sprintf("%s.Init", a.config["name"]))
if err != nil {
return err
}
run := v.Interface().(func() error)
return run()
}
func (a *DynamicPrompt) initializeInterpreter() error {
if _, exists := a.config["code"]; exists && a.i == nil {
unsafe := strings.ToLower(a.config["unsafe"]) == "true"
i := interp.New(interp.Options{
GoPath: a.goPkgPath,
Unrestricted: unsafe,
})
if err := i.Use(stdlib.Symbols); err != nil {
return err
}
if _, exists := a.config["name"]; !exists {
a.config["name"] = "custom"
}
_, err := i.Eval(fmt.Sprintf("package %s\n%s", a.config["name"], a.config["code"]))
if err != nil {
return err
}
a.i = i
}
return nil
}
func (a *DynamicPrompt) Render(c *Agent) (string, error) {
v, err := a.i.Eval(fmt.Sprintf("%s.Render", a.config["name"]))
if err != nil {
return "", err
}
run := v.Interface().(func() (string, error))
return run()
}
func (a *DynamicPrompt) Role() string {
v, err := a.i.Eval(fmt.Sprintf("%s.Role", a.config["name"]))
if err != nil {
return "system"
}
run := v.Interface().(func() string)
return run()
}

View File

@@ -6,6 +6,7 @@ import (
"path/filepath" "path/filepath"
"github.com/mudler/LocalAgent/core/state" "github.com/mudler/LocalAgent/core/state"
"github.com/mudler/LocalAgent/services"
"github.com/mudler/LocalAgent/webui" "github.com/mudler/LocalAgent/webui"
) )
@@ -47,9 +48,9 @@ func main() {
apiKey, apiKey,
stateDir, stateDir,
localRAG, localRAG,
webui.Actions, services.Actions,
webui.Connectors, services.Connectors,
webui.PromptBlocks, services.PromptBlocks,
timeout, timeout,
) )
if err != nil { if err != nil {

View File

@@ -1,4 +1,4 @@
package webui package services
import ( import (
"context" "context"

View File

@@ -1,4 +1,4 @@
package webui package services
import ( import (
"encoding/json" "encoding/json"

View File

@@ -1,9 +1,10 @@
package webui package services
import ( import (
"encoding/json" "encoding/json"
"github.com/mudler/LocalAgent/pkg/xlog" "github.com/mudler/LocalAgent/pkg/xlog"
"github.com/mudler/LocalAgent/services/prompts"
"github.com/mudler/LocalAgent/core/agent" "github.com/mudler/LocalAgent/core/agent"
"github.com/mudler/LocalAgent/core/state" "github.com/mudler/LocalAgent/core/state"
@@ -29,7 +30,7 @@ func PromptBlocks(a *state.AgentConfig) []agent.PromptBlock {
} }
switch c.Type { switch c.Type {
case DynamicPromptCustom: case DynamicPromptCustom:
prompt, err := agent.NewDynamicPrompt(config, "") prompt, err := prompts.NewDynamicPrompt(config, "")
if err != nil { if err != nil {
xlog.Error("Error creating custom prompt", "error", err) xlog.Error("Error creating custom prompt", "error", err)
continue continue

View File

@@ -0,0 +1,97 @@
package prompts
import (
"fmt"
"strings"
"github.com/mudler/LocalAgent/core/agent"
"github.com/mudler/LocalAgent/pkg/xlog"
"github.com/traefik/yaegi/interp"
"github.com/traefik/yaegi/stdlib"
)
type DynamicPrompt struct {
config map[string]string
goPkgPath string
i *interp.Interpreter
}
func NewDynamicPrompt(config map[string]string, goPkgPath string) (*DynamicPrompt, error) {
a := &DynamicPrompt{
config: config,
goPkgPath: goPkgPath,
}
if err := a.initializeInterpreter(); err != nil {
return nil, err
}
if err := a.callInit(); err != nil {
xlog.Error("Error calling custom action init", "error", err)
}
return a, nil
}
func (a *DynamicPrompt) callInit() error {
if a.i == nil {
return nil
}
v, err := a.i.Eval(fmt.Sprintf("%s.Init", a.config["name"]))
if err != nil {
return err
}
run := v.Interface().(func() error)
return run()
}
func (a *DynamicPrompt) initializeInterpreter() error {
if _, exists := a.config["code"]; exists && a.i == nil {
unsafe := strings.ToLower(a.config["unsafe"]) == "true"
i := interp.New(interp.Options{
GoPath: a.goPkgPath,
Unrestricted: unsafe,
})
if err := i.Use(stdlib.Symbols); err != nil {
return err
}
if _, exists := a.config["name"]; !exists {
a.config["name"] = "custom"
}
_, err := i.Eval(fmt.Sprintf("package %s\n%s", a.config["name"], a.config["code"]))
if err != nil {
return err
}
a.i = i
}
return nil
}
func (a *DynamicPrompt) Render(c *agent.Agent) (string, error) {
v, err := a.i.Eval(fmt.Sprintf("%s.Render", a.config["name"]))
if err != nil {
return "", err
}
run := v.Interface().(func() (string, error))
return run()
}
func (a *DynamicPrompt) Role() string {
v, err := a.i.Eval(fmt.Sprintf("%s.Role", a.config["name"]))
if err != nil {
return "system"
}
run := v.Interface().(func() string)
return run()
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/mudler/LocalAgent/core/agent" "github.com/mudler/LocalAgent/core/agent"
"github.com/mudler/LocalAgent/core/sse" "github.com/mudler/LocalAgent/core/sse"
"github.com/mudler/LocalAgent/core/state" "github.com/mudler/LocalAgent/core/state"
"github.com/mudler/LocalAgent/services"
) )
//go:embed views/* //go:embed views/*
@@ -45,9 +46,9 @@ func (app *App) registerRoutes(pool *state.AgentPool, webapp *fiber.App) {
webapp.Get("/create", func(c *fiber.Ctx) error { webapp.Get("/create", func(c *fiber.Ctx) error {
return c.Render("views/create", fiber.Map{ return c.Render("views/create", fiber.Map{
"Actions": AvailableActions, "Actions": services.AvailableActions,
"Connectors": AvailableConnectors, "Connectors": services.AvailableConnectors,
"PromptBlocks": AvailableBlockPrompts, "PromptBlocks": services.AvailableBlockPrompts,
}) })
}) })
@@ -97,8 +98,6 @@ func (app *App) registerRoutes(pool *state.AgentPool, webapp *fiber.App) {
}) })
webapp.Post("/settings/import", app.ImportAgent(pool)) webapp.Post("/settings/import", app.ImportAgent(pool))
webapp.Get("/settings/export/:name", app.ExportAgent(pool)) webapp.Get("/settings/export/:name", app.ExportAgent(pool))
return
} }
var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")