From 8c447a0cf8143a0af5f4a52bb76b77e478c00622 Mon Sep 17 00:00:00 2001 From: mudler Date: Wed, 18 Dec 2024 20:17:05 +0100 Subject: [PATCH] feat: support separate knowledge bases for each agent Also allow to export/import KB Signed-off-by: mudler --- agent/agent.go | 5 ++ example/webui/agentpool.go | 26 +++++++-- example/webui/app.go | 70 ++++++++++++++++++++++-- example/webui/main.go | 19 +------ example/webui/rag.go | 27 ++++++--- example/webui/routes.go | 25 +++++---- example/webui/views/index.html | 49 ++--------------- example/webui/views/knowledgebase.html | 23 ++++++-- example/webui/views/partials/header.html | 42 ++++++++++++++ example/webui/views/partials/menu.html | 3 - example/webui/views/settings.html | 7 ++- llm/rag/chromem.go | 4 ++ llm/rag/localai.go | 4 ++ 13 files changed, 205 insertions(+), 99 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index b38d837..35731d2 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -45,6 +45,7 @@ type RAGDB interface { Store(s string) error Reset() error Search(s string, similarEntries int) ([]string, error) + Count() int } func New(opts ...Option) (*Agent, error) { @@ -235,6 +236,10 @@ func (a *Agent) Paused() bool { return a.pause } +func (a *Agent) Memory() RAGDB { + return a.options.ragdb +} + func (a *Agent) runAction(chosenAction Action, params action.ActionParams) (result string, err error) { for _, action := range a.systemInternalActions() { if action.Definition().Name == chosenAction.Definition().Name { diff --git a/example/webui/agentpool.go b/example/webui/agentpool.go index 935d992..54d9c42 100644 --- a/example/webui/agentpool.go +++ b/example/webui/agentpool.go @@ -24,6 +24,7 @@ type AgentPool struct { agents map[string]*Agent managers map[string]Manager agentStatus map[string]*Status + agentMemory map[string]*InMemoryDatabase apiURL, model string ragDB RAGDB } @@ -76,6 +77,7 @@ func NewAgentPool(model, apiURL, directory string, RagDB RAGDB) (*AgentPool, err pool: make(map[string]AgentConfig), agentStatus: make(map[string]*Status), managers: make(map[string]Manager), + agentMemory: make(map[string]*InMemoryDatabase), }, nil } @@ -92,6 +94,7 @@ func NewAgentPool(model, apiURL, directory string, RagDB RAGDB) (*AgentPool, err agents: make(map[string]*Agent), managers: make(map[string]Manager), agentStatus: map[string]*Status{}, + agentMemory: map[string]*InMemoryDatabase{}, pool: *poolData, }, nil } @@ -144,7 +147,14 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error actions := config.availableActions(ctx) - stateFile, characterFile := a.stateFiles(name) + stateFile, characterFile, knowledgeBase := a.stateFiles(name) + + agentDB, err := NewInMemoryDB(knowledgeBase, a.ragDB) + if err != nil { + return err + } + + a.agentMemory[name] = agentDB actionsLog := []string{} for _, action := range actions { @@ -179,7 +189,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error WithStateFile(stateFile), WithCharacterFile(characterFile), WithTimeout(timeout), - WithRAGDB(a.ragDB), + WithRAGDB(agentDB), WithAgentReasoningCallback(func(state ActionCurrentState) bool { xlog.Info( "Agent is thinking", @@ -355,20 +365,22 @@ func (a *AgentPool) Start(name string) error { return fmt.Errorf("agent %s not found", name) } -func (a *AgentPool) stateFiles(name string) (string, string) { +func (a *AgentPool) stateFiles(name string) (string, string, string) { stateFile := filepath.Join(a.pooldir, fmt.Sprintf("%s.state.json", name)) characterFile := filepath.Join(a.pooldir, fmt.Sprintf("%s.character.json", name)) + knowledgeBaseFile := filepath.Join(a.pooldir, fmt.Sprintf("%s.knowledgebase.json", name)) - return stateFile, characterFile + return stateFile, characterFile, knowledgeBaseFile } func (a *AgentPool) Remove(name string) error { // Cleanup character and state - stateFile, characterFile := a.stateFiles(name) + stateFile, characterFile, knowledgeBaseFile := a.stateFiles(name) os.Remove(stateFile) os.Remove(characterFile) + os.Remove(knowledgeBaseFile) a.Stop(name) delete(a.agents, name) @@ -391,6 +403,10 @@ func (a *AgentPool) GetAgent(name string) *Agent { return a.agents[name] } +func (a *AgentPool) GetAgentMemory(name string) *InMemoryDatabase { + return a.agentMemory[name] +} + func (a *AgentPool) GetConfig(name string) *AgentConfig { agent, exists := a.pool[name] if !exists { diff --git a/example/webui/app.go b/example/webui/app.go index cb0f0f5..58d27f6 100644 --- a/example/webui/app.go +++ b/example/webui/app.go @@ -24,15 +24,72 @@ type ( } ) -func (a *App) KnowledgeBaseReset(db *InMemoryDatabase) func(c *fiber.Ctx) error { +func (a *App) KnowledgeBaseReset(pool *AgentPool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { + db := pool.GetAgentMemory(c.Params("name")) db.Reset() - return c.Redirect("/knowledgebase") + return c.Redirect("/knowledgebase/" + c.Params("name")) } } -func (a *App) KnowledgeBaseFile(db *InMemoryDatabase) func(c *fiber.Ctx) error { +func (a *App) KnowledgeBaseExport(pool *AgentPool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { + db := pool.GetAgentMemory(c.Params("name")) + knowledgeBase := db.Data() + + c.Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.knowledgebase.json", c.Params("name"))) + return c.JSON(knowledgeBase) + } +} + +func (a *App) KnowledgeBaseImport(pool *AgentPool) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + file, err := c.FormFile("file") + if err != nil { + // Handle error + return err + } + + os.MkdirAll("./uploads", os.ModePerm) + + destination := fmt.Sprintf("./uploads/%s", file.Filename) + if err := c.SaveFile(file, destination); err != nil { + // Handle error + return err + } + + data, err := os.ReadFile(destination) + if err != nil { + return err + } + + knowledge := []string{} + if err := json.Unmarshal(data, &knowledge); err != nil { + return err + } + + if len(knowledge) > 0 { + xlog.Info("Importing agent KB") + db := pool.GetAgentMemory(c.Params("name")) + db.Reset() + + for _, k := range knowledge { + db.Store(k) + } + + } else { + return fmt.Errorf("Empty knowledge base") + } + + return c.Redirect("/agents") + } +} + +func (a *App) KnowledgeBaseFile(pool *AgentPool) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + agent := pool.GetAgent(c.Params("name")) + db := agent.Memory() + // https://golang.withcodeexample.com/blog/file-upload-handling-golang-fiber-guide/ file, err := c.FormFile("file") if err != nil { @@ -78,8 +135,11 @@ func (a *App) KnowledgeBaseFile(db *InMemoryDatabase) func(c *fiber.Ctx) error { } } -func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error { +func (a *App) KnowledgeBase(pool *AgentPool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { + agent := pool.GetAgent(c.Params("name")) + db := agent.Memory() + payload := struct { URL string `form:"url"` ChunkSize int `form:"chunk_size"` @@ -100,7 +160,7 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error { go WebsiteToKB(website, chunkSize, db) - return c.Redirect("/knowledgebase") + return c.Redirect("/knowledgebase/" + c.Params("name")) } } diff --git a/example/webui/main.go b/example/webui/main.go index 6b21d1a..ff3a9b3 100644 --- a/example/webui/main.go +++ b/example/webui/main.go @@ -6,8 +6,6 @@ import ( "net/http" "os" - "github.com/mudler/local-agent-framework/xlog" - "github.com/donseba/go-htmx" fiber "github.com/gofiber/fiber/v2" "github.com/gofiber/template/html/v2" @@ -21,7 +19,6 @@ var testModel = os.Getenv("TEST_MODEL") var apiURL = os.Getenv("API_URL") var apiKey = os.Getenv("API_KEY") var vectorStore = os.Getenv("VECTOR_STORE") -var kbdisableIndexing = os.Getenv("KBDISABLEINDEX") var timeout = os.Getenv("TIMEOUT") var embeddingModel = os.Getenv("EMBEDDING_MODEL") @@ -70,23 +67,11 @@ func main() { } } - db, err := NewInMemoryDB(stateDir, ragDB) + pool, err := NewAgentPool(testModel, apiURL, stateDir, ragDB) if err != nil { panic(err) } - pool, err := NewAgentPool(testModel, apiURL, stateDir, db) - if err != nil { - panic(err) - } - - if len(db.Database) > 0 && kbdisableIndexing != "true" { - xlog.Info("Loading knowledgebase from disk, to skip run with KBDISABLEINDEX=true") - if err := db.PopulateRAGDB(); err != nil { - xlog.Info("Error storing in the KB", err) - } - } - app := &App{ htmx: htmx.New(), pool: pool, @@ -102,7 +87,7 @@ func main() { Views: engine, }) - RegisterRoutes(webapp, pool, db, app) + RegisterRoutes(webapp, pool, app) log.Fatal(webapp.Listen(":3000")) } diff --git a/example/webui/rag.go b/example/webui/rag.go index 30bbeae..cfc0024 100644 --- a/example/webui/rag.go +++ b/example/webui/rag.go @@ -9,7 +9,6 @@ import ( "net/http" "os" - "path/filepath" "strings" "sync" @@ -37,12 +36,10 @@ func loadDB(path string) ([]string, error) { return poolData, err } -func NewInMemoryDB(knowledgebase string, store RAGDB) (*InMemoryDatabase, error) { +func NewInMemoryDB(poolfile string, store RAGDB) (*InMemoryDatabase, error) { // if file exists, try to load an existing pool. // if file does not exist, create a new pool. - poolfile := filepath.Join(knowledgebase, "knowledgebase.json") - if _, err := os.Stat(poolfile); err != nil { // file does not exist, return a new pool return &InMemoryDatabase{ @@ -56,14 +53,26 @@ func NewInMemoryDB(knowledgebase string, store RAGDB) (*InMemoryDatabase, error) if err != nil { return nil, err } - return &InMemoryDatabase{ + db := &InMemoryDatabase{ RAGDB: store, Database: poolData, path: poolfile, - }, nil + } + + if err := db.populateRAGDB(); err != nil { + return nil, fmt.Errorf("error populating RAGDB: %w", err) + } + + return db, nil } -func (db *InMemoryDatabase) PopulateRAGDB() error { +func (db *InMemoryDatabase) Data() []string { + db.Lock() + defer db.Unlock() + return db.Database +} + +func (db *InMemoryDatabase) populateRAGDB() error { for _, d := range db.Database { if d == "" { // skip empty chunks @@ -139,7 +148,7 @@ func getWebSitemap(url string) (res []string, err error) { return } -func WebsiteToKB(website string, chunkSize int, db *InMemoryDatabase) { +func WebsiteToKB(website string, chunkSize int, db RAGDB) { content, err := getWebSitemap(website) if err != nil { xlog.Info("Error walking sitemap for website", err) @@ -150,7 +159,7 @@ func WebsiteToKB(website string, chunkSize int, db *InMemoryDatabase) { StringsToKB(db, chunkSize, content...) } -func StringsToKB(db *InMemoryDatabase, chunkSize int, content ...string) { +func StringsToKB(db RAGDB, chunkSize int, content ...string) { for _, c := range content { chunks := splitParagraphIntoChunks(c, chunkSize) xlog.Info("chunks: ", len(chunks)) diff --git a/example/webui/routes.go b/example/webui/routes.go index 2109642..f719b90 100644 --- a/example/webui/routes.go +++ b/example/webui/routes.go @@ -9,7 +9,7 @@ import ( "github.com/mudler/local-agent-framework/agent" ) -func RegisterRoutes(webapp *fiber.App, pool *AgentPool, db *InMemoryDatabase, app *App) { +func RegisterRoutes(webapp *fiber.App, pool *AgentPool, app *App) { webapp.Use("/public", filesystem.New(filesystem.Config{ Root: http.FS(embeddedFiles), @@ -36,17 +36,20 @@ func RegisterRoutes(webapp *fiber.App, pool *AgentPool, db *InMemoryDatabase, ap webapp.Get("/create", func(c *fiber.Ctx) error { return c.Render("views/create", fiber.Map{ - "Title": "Hello, World!", "Actions": AvailableActions, "Connectors": AvailableConnectors, }) }) - webapp.Get("/knowledgebase", func(c *fiber.Ctx) error { - return c.Render("views/knowledgebase", fiber.Map{ - "Title": "Hello, World!", - "KnowledgebaseItemsCount": len(db.Database), - }) + webapp.Get("/knowledgebase/:name", func(c *fiber.Ctx) error { + db := pool.GetAgentMemory(c.Params("name")) + return c.Render( + "views/knowledgebase", + fiber.Map{ + "KnowledgebaseItemsCount": db.Count(), + "Name": c.Params("name"), + }, + ) }) // Define a route for the GET method on the root path '/' @@ -80,9 +83,11 @@ func RegisterRoutes(webapp *fiber.App, pool *AgentPool, db *InMemoryDatabase, ap webapp.Put("/pause/:name", app.Pause(pool)) webapp.Put("/start/:name", app.Start(pool)) - webapp.Post("/knowledgebase", app.KnowledgeBase(db)) - webapp.Post("/knowledgebase/upload", app.KnowledgeBaseFile(db)) - webapp.Delete("/knowledgebase/reset", app.KnowledgeBaseReset(db)) + webapp.Post("/knowledgebase/:name", app.KnowledgeBase(pool)) + webapp.Post("/knowledgebase/:name/upload", app.KnowledgeBaseFile(pool)) + webapp.Delete("/knowledgebase/:name/reset", app.KnowledgeBaseReset(pool)) + webapp.Post("/knowledgebase/:name/import", app.KnowledgeBaseImport(pool)) + webapp.Get("/knowledgebase/:name/export", app.KnowledgeBaseExport(pool)) webapp.Get("/talk/:name", func(c *fiber.Ctx) error { return c.Render("views/chat", fiber.Map{ diff --git a/example/webui/views/index.html b/example/webui/views/index.html index b03fbbf..82283df 100644 --- a/example/webui/views/index.html +++ b/example/webui/views/index.html @@ -4,47 +4,7 @@ Smart Assistant Dashboard {{template "views/partials/header"}} @@ -60,11 +20,10 @@

View and manage your list of agents, including detailed profiles and statistics.

- - +
-

Manage Knowledgebase

-

Access and update your knowledgebase to improve agent responses and efficiency.

+

Create

+

Create a new agent.

diff --git a/example/webui/views/knowledgebase.html b/example/webui/views/knowledgebase.html index 7f68770..c2cea87 100644 --- a/example/webui/views/knowledgebase.html +++ b/example/webui/views/knowledgebase.html @@ -1,7 +1,7 @@ - KnowledgeBase + Knowledgebase for {{.Name}} {{template "views/partials/header"}} @@ -12,7 +12,7 @@
-
+

Add sites to KB

@@ -24,7 +24,7 @@
- +

Upload File

@@ -37,7 +37,22 @@
- + +
+ +
+

Export

+ Export +
+ +
+ +

Import

+ + + + +