diff --git a/agent/agent.go b/agent/agent.go index eaa2a41..7560f61 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -23,7 +23,6 @@ type Agent struct { options *options Character Character client *openai.Client - storeClient *llm.StoreClient jobQueue chan *Job actionContext *action.ActionContext context *action.ActionContext @@ -37,6 +36,11 @@ type Agent struct { newConversations chan openai.ChatCompletionMessage } +type RAGDB interface { + Store(s string) error + Search(s string, similarEntries int) ([]string, error) +} + func New(opts ...Option) (*Agent, error) { options, err := newOptions(opts...) if err != nil { @@ -44,7 +48,6 @@ func New(opts ...Option) (*Agent, error) { } client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL) - storeClient := llm.NewStoreClient(options.LLMAPI.APIURL, options.LLMAPI.APIKey) c := context.Background() if options.context != nil { @@ -58,7 +61,6 @@ func New(opts ...Option) (*Agent, error) { client: client, Character: options.character, currentState: &action.StateResult{}, - storeClient: storeClient, context: action.NewContext(ctx, cancel), } @@ -204,7 +206,7 @@ func (a *Agent) consumeJob(job *Job, role string) { } // RAG - if a.options.enableKB { + if a.options.enableKB && a.options.ragdb != nil { // Walk conversation from bottom to top, and find the first message of the user // to use it as a query to the KB var userMessage string @@ -216,7 +218,7 @@ func (a *Agent) consumeJob(job *Job, role string) { } if userMessage != "" { - results, err := llm.FindSimilarStrings(a.storeClient, a.client, userMessage, a.options.kbResults) + results, err := a.options.ragdb.Search(userMessage, a.options.kbResults) if err != nil { if a.options.debugMode { fmt.Println("Error finding similar strings inside KB:", err) diff --git a/agent/options.go b/agent/options.go index 1d7cde8..82b4ef3 100644 --- a/agent/options.go +++ b/agent/options.go @@ -28,6 +28,7 @@ type options struct { permanentGoal string periodicRuns time.Duration 
kbResults int + ragdb RAGDB // callbacks reasoningCallback func(ActionCurrentState) bool @@ -102,6 +103,13 @@ var EnablePersonality = func(o *options) error { return nil } +func WithRAGDB(db RAGDB) Option { + return func(o *options) error { + o.ragdb = db + return nil + } +} + func WithLLMAPIURL(url string) Option { return func(o *options) error { o.LLMAPI.APIURL = url diff --git a/example/webui/agentpool.go b/example/webui/agentpool.go index 3005e8d..b8e6336 100644 --- a/example/webui/agentpool.go +++ b/example/webui/agentpool.go @@ -48,6 +48,7 @@ type AgentPool struct { agents map[string]*Agent managers map[string]Manager apiURL, model string + ragDB RAGDB } type AgentPoolData map[string]AgentConfig @@ -63,7 +64,7 @@ func loadPoolFromFile(path string) (*AgentPoolData, error) { return poolData, err } -func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) { +func NewAgentPool(model, apiURL, directory string, RagDB RAGDB) (*AgentPool, error) { // if file exists, try to load an existing pool. // if file does not exist, create a new pool. 
@@ -76,6 +77,7 @@ func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) { pooldir: directory, apiURL: apiURL, model: model, + ragDB: RagDB, agents: make(map[string]*Agent), pool: make(map[string]AgentConfig), managers: make(map[string]Manager), @@ -90,6 +92,7 @@ func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) { file: poolfile, apiURL: apiURL, pooldir: directory, + ragDB: RagDB, model: model, agents: make(map[string]*Agent), managers: make(map[string]Manager), @@ -229,6 +232,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error ), WithStateFile(stateFile), WithCharacterFile(characterFile), + WithRAGDB(a.ragDB), WithAgentReasoningCallback(func(state ActionCurrentState) bool { fmt.Println("Reasoning", state.Reasoning) manager.Send( diff --git a/example/webui/main.go b/example/webui/main.go index 4fc6d95..15fb1a0 100644 --- a/example/webui/main.go +++ b/example/webui/main.go @@ -12,6 +12,8 @@ import ( fiber "github.com/gofiber/fiber/v2" . 
"github.com/mudler/local-agent-framework/agent" + "github.com/mudler/local-agent-framework/llm" + "github.com/mudler/local-agent-framework/llm/rag" ) type ( @@ -24,6 +26,9 @@ type ( var testModel = os.Getenv("TEST_MODEL") var apiURL = os.Getenv("API_URL") var apiKey = os.Getenv("API_KEY") +var vectorStore = os.Getenv("VECTOR_STORE") + +const defaultChunkSize = 4098 func init() { if testModel == "" { @@ -50,12 +55,27 @@ func main() { stateDir := cwd + "/pool" os.MkdirAll(stateDir, 0755) - pool, err := NewAgentPool(testModel, apiURL, stateDir) + var dbStore RAGDB + lai := llm.NewClient(apiKey, apiURL+"/v1") + + switch vectorStore { + case "localai": + laiStore := rag.NewStoreClient(apiURL, apiKey) + dbStore = rag.NewLocalAIRAGDB(laiStore, lai) + default: + var err error + dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai) + if err != nil { + panic(err) + } + } + + pool, err := NewAgentPool(testModel, apiURL, stateDir, dbStore) if err != nil { panic(err) } - db, err := NewInMemoryDB(stateDir) + db, err := NewInMemoryDB(stateDir, dbStore) if err != nil { panic(err) } @@ -144,6 +164,10 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error { if website == "" { return fmt.Errorf("please enter a URL") } + chunkSize := defaultChunkSize + if payload.ChunkSize > 0 { + chunkSize = payload.ChunkSize + } go func() { content, err := Sitemap(website) @@ -151,9 +175,10 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error { fmt.Println("Error walking sitemap for website", err) } fmt.Println("Found pages: ", len(content)) + fmt.Println("ChunkSize: ", chunkSize) for _, c := range content { - chunks := splitParagraphIntoChunks(c, payload.ChunkSize) + chunks := splitParagraphIntoChunks(c, chunkSize) fmt.Println("chunks: ", len(chunks)) for _, chunk := range chunks { db.AddEntry(chunk) @@ -162,7 +187,7 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error { db.SaveDB() } - if err := 
db.SaveToStore(apiKey, apiURL); err != nil { + if err := db.SaveToStore(); err != nil { fmt.Println("Error storing in the KB", err) } }() diff --git a/example/webui/rag.go b/example/webui/rag.go index ace8345..375a92a 100644 --- a/example/webui/rag.go +++ b/example/webui/rag.go @@ -10,9 +10,9 @@ import ( "strings" "sync" + . "github.com/mudler/local-agent-framework/agent" "jaytaylor.com/html2text" - "github.com/mudler/local-agent-framework/llm" sitemap "github.com/oxffaa/gopher-parse-sitemap" ) @@ -20,6 +20,7 @@ type InMemoryDatabase struct { sync.Mutex Database []string path string + rag RAGDB } func loadDB(path string) ([]string, error) { @@ -33,7 +34,7 @@ func loadDB(path string) ([]string, error) { return poolData, err } -func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) { +func NewInMemoryDB(knowledgebase string, store RAGDB) (*InMemoryDatabase, error) { // if file exists, try to load an existing pool. // if file does not exist, create a new pool. @@ -44,6 +45,7 @@ func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) { return &InMemoryDatabase{ Database: []string{}, path: poolfile, + rag: store, }, nil } @@ -54,15 +56,17 @@ func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) { return &InMemoryDatabase{ Database: poolData, path: poolfile, + rag: store, }, nil } -func (db *InMemoryDatabase) SaveToStore(apiKey string, apiURL string) error { +func (db *InMemoryDatabase) SaveToStore() error { for _, d := range db.Database { - lai := llm.NewClient(apiKey, apiURL+"/v1") - laiStore := llm.NewStoreClient(apiURL, apiKey) - - err := llm.StoreStringEmbeddingInVectorDB(laiStore, lai, d) + if d == "" { + // skip empty chunks + continue + } + err := db.rag.Store(d) if err != nil { return fmt.Errorf("Error storing in the KB: %w", err) } @@ -119,8 +123,6 @@ func Sitemap(url string) (res []string, err error) { // and returns a slice of strings where each string is a chunk of the paragraph // that is at most maxChunkSize long, 
ensuring that words are not split. func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string { - // Check if the paragraph length is less than or equal to maxChunkSize. - // If so, return the paragraph as the only chunk. if len(paragraph) <= maxChunkSize { return []string{paragraph} } @@ -131,11 +133,14 @@ func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string { words := strings.Fields(paragraph) // Splits the paragraph into words. for _, word := range words { - // Check if adding the next word would exceed the maxChunkSize. - // If so, add the currentChunk to the chunks slice and start a new chunk. - if currentChunk.Len()+len(word) > maxChunkSize { + // If adding the next word would exceed maxChunkSize (considering a space if not the first word in a chunk), + // add the currentChunk to chunks, and reset currentChunk. + if currentChunk.Len() > 0 && currentChunk.Len()+len(word)+1 > maxChunkSize { // +1 for the space if not the first word chunks = append(chunks, currentChunk.String()) currentChunk.Reset() + } else if currentChunk.Len() == 0 && len(word) > maxChunkSize { // Word itself exceeds maxChunkSize; emit it whole as its own oversized chunk (words are never split) + chunks = append(chunks, word) + continue } // Add a space before the word if it's not the beginning of a new chunk. @@ -147,7 +152,7 @@ func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string { currentChunk.WriteString(word) } - // Add the last chunk if it's not empty. + // After the loop, add any remaining content in currentChunk to chunks. 
if currentChunk.Len() > 0 { chunks = append(chunks, currentChunk.String()) } diff --git a/go.mod b/go.mod index a729aa0..9aeb6ab 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/mattn/go-runewidth v0.0.15 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 // indirect + github.com/philippgille/chromem-go v0.5.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect github.com/stretchr/testify v1.9.0 // indirect diff --git a/go.sum b/go.sum index c39285b..b1deafb 100644 --- a/go.sum +++ b/go.sum @@ -59,6 +59,8 @@ github.com/onsi/gomega v1.31.1 h1:KYppCUK+bUgAZwHOu7EXVBKyQA6ILvOESHkn/tgoqvo= github.com/onsi/gomega v1.31.1/go.mod h1:y40C95dwAD1Nz36SsEnxvfFe8FFfNxzI5eJ0EYGyAy0= github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 h1:2vmb32OdDhjZf2ETGDlr9n8RYXx7c+jXPxMiPbwnA+8= github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4/go.mod h1:2JQx4jDHmWrbABvpOayg/+OTU6ehN0IyK2EHzceXpJo= +github.com/philippgille/chromem-go v0.5.0 h1:bryX0F3N6jnN/21iBd8i2/k9EzPTZn3nyiqAti19si8= +github.com/philippgille/chromem-go v0.5.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= diff --git a/llm/rag/rag_chromem.go b/llm/rag/rag_chromem.go new file mode 100644 index 0000000..a14486e --- /dev/null +++ b/llm/rag/rag_chromem.go @@ -0,0 +1,88 @@ +package rag + +import ( + "context" + "fmt" + "runtime" + + "github.com/philippgille/chromem-go" + "github.com/sashabaranov/go-openai" +) + +type ChromemDB struct { + collectionName string + collection *chromem.Collection + index int +} + +func NewChromemDB(collection, path 
string, openaiClient *openai.Client) (*ChromemDB, error) { + // db, err := chromem.NewPersistentDB(path, true) + // if err != nil { + // return nil, err + // } + db := chromem.NewDB() + + embeddingFunc := chromem.EmbeddingFunc( + func(ctx context.Context, text string) ([]float32, error) { + fmt.Println("Creating embeddings") + resp, err := openaiClient.CreateEmbeddings(ctx, + openai.EmbeddingRequestStrings{ + Input: []string{text}, + Model: openai.AdaEmbeddingV2, + }, + ) + if err != nil { + return []float32{}, fmt.Errorf("error getting keys: %v", err) + } + + if len(resp.Data) == 0 { + return []float32{}, fmt.Errorf("no response from OpenAI API") + } + + embedding := resp.Data[0].Embedding + + return embedding, nil + }, + ) + + c, err := db.GetOrCreateCollection(collection, nil, embeddingFunc) + if err != nil { + return nil, err + } + + return &ChromemDB{ + collectionName: collection, + collection: c, + index: 1, + }, nil +} + +func (c *ChromemDB) Store(s string) error { + defer func() { + c.index++ + }() + if s == "" { + return fmt.Errorf("empty string") + } + fmt.Println("Trying to store", s) + return c.collection.AddDocuments(context.Background(), []chromem.Document{ + { + Content: s, + ID: fmt.Sprint(c.index), + }, + }, runtime.NumCPU()) +} + +func (c *ChromemDB) Search(s string, similarEntries int) ([]string, error) { + res, err := c.collection.Query(context.Background(), s, similarEntries, nil, nil) + if err != nil { + return nil, err + } + + var results []string + for _, r := range res { + results = append(results, r.Content) + } + + return results, nil +} diff --git a/llm/rag.go b/llm/rag/rag_localai.go similarity index 65% rename from llm/rag.go rename to llm/rag/rag_localai.go index af0371b..7e8bf3a 100644 --- a/llm/rag.go +++ b/llm/rag/rag_localai.go @@ -1,4 +1,4 @@ -package llm +package rag import ( "context" @@ -7,8 +7,20 @@ import ( "github.com/sashabaranov/go-openai" ) -func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient 
*openai.Client, s string) error { - resp, err := openaiClient.CreateEmbeddings(context.TODO(), +type LocalAIRAGDB struct { + client *StoreClient + openaiClient *openai.Client +} + +func NewLocalAIRAGDB(storeClient *StoreClient, openaiClient *openai.Client) *LocalAIRAGDB { + return &LocalAIRAGDB{ + client: storeClient, + openaiClient: openaiClient, + } +} + +func (db *LocalAIRAGDB) Store(s string) error { + resp, err := db.openaiClient.CreateEmbeddings(context.TODO(), openai.EmbeddingRequestStrings{ Input: []string{s}, Model: openai.AdaEmbeddingV2, @@ -28,7 +40,7 @@ func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Cl Keys: [][]float32{embedding}, Values: []string{s}, } - err = client.Set(setReq) + err = db.client.Set(setReq) if err != nil { return fmt.Errorf("error setting keys: %v", err) } @@ -36,9 +48,8 @@ func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Cl return nil } -func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) { - - resp, err := openaiClient.CreateEmbeddings(context.TODO(), +func (db *LocalAIRAGDB) Search(s string, similarEntries int) ([]string, error) { + resp, err := db.openaiClient.CreateEmbeddings(context.TODO(), openai.EmbeddingRequestStrings{ Input: []string{s}, Model: openai.AdaEmbeddingV2, @@ -58,7 +69,7 @@ func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s stri TopK: similarEntries, // Number of similar entries you want to find Key: embedding, // The key you're looking for similarities to } - findResp, err := client.Find(findReq) + findResp, err := db.client.Find(findReq) if err != nil { return []string{}, fmt.Errorf("error finding keys: %v", err) } diff --git a/llm/store.go b/llm/rag/store.go similarity index 99% rename from llm/store.go rename to llm/rag/store.go index 3b85568..1980cf9 100644 --- a/llm/store.go +++ b/llm/rag/store.go @@ -1,4 +1,4 @@ -package llm +package rag import ( "bytes"