Make Knowledgebase RAG functional (almost)

mudler
2024-04-10 17:38:16 +02:00
parent 48d17b6952
commit 73524adfce
10 changed files with 178 additions and 32 deletions

View File

@@ -23,7 +23,6 @@ type Agent struct {
     options       *options
     Character     Character
     client        *openai.Client
-    storeClient   *llm.StoreClient
     jobQueue      chan *Job
     actionContext *action.ActionContext
     context       *action.ActionContext
@@ -37,6 +36,11 @@ type Agent struct {
     newConversations chan openai.ChatCompletionMessage
 }

+type RAGDB interface {
+    Store(s string) error
+    Search(s string, similarEntries int) ([]string, error)
+}
+
 func New(opts ...Option) (*Agent, error) {
     options, err := newOptions(opts...)
     if err != nil {
@@ -44,7 +48,6 @@ func New(opts ...Option) (*Agent, error) {
     }

     client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL)
-    storeClient := llm.NewStoreClient(options.LLMAPI.APIURL, options.LLMAPI.APIKey)

     c := context.Background()
     if options.context != nil {
@@ -58,7 +61,6 @@ func New(opts ...Option) (*Agent, error) {
         client:       client,
         Character:    options.character,
         currentState: &action.StateResult{},
-        storeClient:  storeClient,
         context:      action.NewContext(ctx, cancel),
     }
@@ -204,7 +206,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
     }

     // RAG
-    if a.options.enableKB {
+    if a.options.enableKB && a.options.ragdb != nil {
         // Walk conversation from bottom to top, and find the first message of the user
         // to use it as a query to the KB
         var userMessage string
@@ -216,7 +218,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
         }

         if userMessage != "" {
-            results, err := llm.FindSimilarStrings(a.storeClient, a.client, userMessage, a.options.kbResults)
+            results, err := a.options.ragdb.Search(userMessage, a.options.kbResults)
             if err != nil {
                 if a.options.debugMode {
                     fmt.Println("Error finding similar strings inside KB:", err)

View File

@@ -28,6 +28,7 @@ type options struct {
     permanentGoal     string
     periodicRuns      time.Duration
     kbResults         int
+    ragdb             RAGDB

     // callbacks
     reasoningCallback func(ActionCurrentState) bool
@@ -102,6 +103,13 @@ var EnablePersonality = func(o *options) error {
     return nil
 }

+func WithRAGDB(db RAGDB) Option {
+    return func(o *options) error {
+        o.ragdb = db
+        return nil
+    }
+}
+
 func WithLLMAPIURL(url string) Option {
     return func(o *options) error {
         o.LLMAPI.APIURL = url
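A hedged sketch of how the new option is meant to be consumed. The import paths and constructor signatures are the ones visible in this commit; the endpoint URL, API key, and collection name are placeholders, and no options beyond those shown in the diff are assumed:

package main

import (
    "github.com/mudler/local-agent-framework/agent"
    "github.com/mudler/local-agent-framework/llm"
    "github.com/mudler/local-agent-framework/llm/rag"
)

func main() {
    // llm.NewClient takes (apiKey, apiURL), as used elsewhere in this commit.
    lai := llm.NewClient("api-key", "http://localhost:8080/v1")

    // Back the agent's knowledge base with the chromem-go store added here.
    store, err := rag.NewChromemDB("example-collection", "./state", lai)
    if err != nil {
        panic(err)
    }

    a, err := agent.New(
        agent.WithLLMAPIURL("http://localhost:8080/v1"),
        agent.WithRAGDB(store),
    )
    if err != nil {
        panic(err)
    }
    _ = a
}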

View File

@@ -48,6 +48,7 @@ type AgentPool struct {
     agents        map[string]*Agent
     managers      map[string]Manager
     apiURL, model string
+    ragDB         RAGDB
 }

 type AgentPoolData map[string]AgentConfig
@@ -63,7 +64,7 @@ func loadPoolFromFile(path string) (*AgentPoolData, error) {
     return poolData, err
 }

-func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) {
+func NewAgentPool(model, apiURL, directory string, RagDB RAGDB) (*AgentPool, error) {
     // if file exists, try to load an existing pool.
     // if file does not exist, create a new pool.
@@ -76,6 +77,7 @@ func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) {
             pooldir:  directory,
             apiURL:   apiURL,
             model:    model,
+            ragDB:    RagDB,
             agents:   make(map[string]*Agent),
             pool:     make(map[string]AgentConfig),
             managers: make(map[string]Manager),
@@ -90,6 +92,7 @@ func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) {
         file:     poolfile,
         apiURL:   apiURL,
         pooldir:  directory,
+        ragDB:    RagDB,
         model:    model,
         agents:   make(map[string]*Agent),
         managers: make(map[string]Manager),
@@ -229,6 +232,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error
         ),
         WithStateFile(stateFile),
         WithCharacterFile(characterFile),
+        WithRAGDB(a.ragDB),
         WithAgentReasoningCallback(func(state ActionCurrentState) bool {
             fmt.Println("Reasoning", state.Reasoning)
             manager.Send(

View File

@@ -12,6 +12,8 @@ import (
     fiber "github.com/gofiber/fiber/v2"

     . "github.com/mudler/local-agent-framework/agent"
+    "github.com/mudler/local-agent-framework/llm"
+    "github.com/mudler/local-agent-framework/llm/rag"
 )

 type (
@@ -24,6 +26,9 @@ type (
 var testModel = os.Getenv("TEST_MODEL")
 var apiURL = os.Getenv("API_URL")
 var apiKey = os.Getenv("API_KEY")
+var vectorStore = os.Getenv("VECTOR_STORE")
+
+const defaultChunkSize = 4098

 func init() {
     if testModel == "" {
@@ -50,12 +55,27 @@ func main() {
     stateDir := cwd + "/pool"
     os.MkdirAll(stateDir, 0755)

-    pool, err := NewAgentPool(testModel, apiURL, stateDir)
+    var dbStore RAGDB
+    lai := llm.NewClient(apiKey, apiURL+"/v1")
+    switch vectorStore {
+    case "localai":
+        laiStore := rag.NewStoreClient(apiURL, apiKey)
+        dbStore = rag.NewLocalAIRAGDB(laiStore, lai)
+    default:
+        var err error
+        dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai)
+        if err != nil {
+            panic(err)
+        }
+    }
+
+    pool, err := NewAgentPool(testModel, apiURL, stateDir, dbStore)
     if err != nil {
         panic(err)
     }
-    db, err := NewInMemoryDB(stateDir)
+    db, err := NewInMemoryDB(stateDir, dbStore)
     if err != nil {
         panic(err)
     }
@@ -144,6 +164,10 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error {
         if website == "" {
             return fmt.Errorf("please enter a URL")
         }
+        chunkSize := defaultChunkSize
+        if payload.ChunkSize > 0 {
+            chunkSize = payload.ChunkSize
+        }

         go func() {
             content, err := Sitemap(website)
@@ -151,9 +175,10 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error {
                 fmt.Println("Error walking sitemap for website", err)
             }
             fmt.Println("Found pages: ", len(content))
+            fmt.Println("ChunkSize: ", chunkSize)
             for _, c := range content {
-                chunks := splitParagraphIntoChunks(c, payload.ChunkSize)
+                chunks := splitParagraphIntoChunks(c, chunkSize)
                 fmt.Println("chunks: ", len(chunks))
                 for _, chunk := range chunks {
                     db.AddEntry(chunk)
@@ -162,7 +187,7 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error {
                 db.SaveDB()
             }

-            if err := db.SaveToStore(apiKey, apiURL); err != nil {
+            if err := db.SaveToStore(); err != nil {
                 fmt.Println("Error storing in the KB", err)
             }
         }()

View File

@@ -10,9 +10,9 @@ import (
     "strings"
     "sync"

+    . "github.com/mudler/local-agent-framework/agent"
     "jaytaylor.com/html2text"

-    "github.com/mudler/local-agent-framework/llm"
     sitemap "github.com/oxffaa/gopher-parse-sitemap"
 )
@@ -20,6 +20,7 @@ type InMemoryDatabase struct {
     sync.Mutex
     Database []string
     path     string
+    rag      RAGDB
 }

 func loadDB(path string) ([]string, error) {
@@ -33,7 +34,7 @@ func loadDB(path string) ([]string, error) {
     return poolData, err
 }

-func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) {
+func NewInMemoryDB(knowledgebase string, store RAGDB) (*InMemoryDatabase, error) {
     // if file exists, try to load an existing pool.
     // if file does not exist, create a new pool.
@@ -44,6 +45,7 @@ func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) {
         return &InMemoryDatabase{
             Database: []string{},
             path:     poolfile,
+            rag:      store,
         }, nil
     }
@@ -54,15 +56,17 @@ func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) {
     return &InMemoryDatabase{
         Database: poolData,
         path:     poolfile,
+        rag:      store,
     }, nil
 }

-func (db *InMemoryDatabase) SaveToStore(apiKey string, apiURL string) error {
+func (db *InMemoryDatabase) SaveToStore() error {
     for _, d := range db.Database {
-        lai := llm.NewClient(apiKey, apiURL+"/v1")
-        laiStore := llm.NewStoreClient(apiURL, apiKey)
-
-        err := llm.StoreStringEmbeddingInVectorDB(laiStore, lai, d)
+        if d == "" {
+            // skip empty chunks
+            continue
+        }
+        err := db.rag.Store(d)
         if err != nil {
             return fmt.Errorf("Error storing in the KB: %w", err)
         }
@@ -119,8 +123,6 @@ func Sitemap(url string) (res []string, err error) {
 // and returns a slice of strings where each string is a chunk of the paragraph
 // that is at most maxChunkSize long, ensuring that words are not split.
 func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
-    // Check if the paragraph length is less than or equal to maxChunkSize.
-    // If so, return the paragraph as the only chunk.
     if len(paragraph) <= maxChunkSize {
         return []string{paragraph}
     }
@@ -131,11 +133,14 @@ func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
     words := strings.Fields(paragraph) // Splits the paragraph into words.

     for _, word := range words {
-        // Check if adding the next word would exceed the maxChunkSize.
-        // If so, add the currentChunk to the chunks slice and start a new chunk.
-        if currentChunk.Len()+len(word) > maxChunkSize {
+        // If adding the next word would exceed maxChunkSize (considering a space if not the first word in a chunk),
+        // add the currentChunk to chunks, and reset currentChunk.
+        if currentChunk.Len() > 0 && currentChunk.Len()+len(word)+1 > maxChunkSize { // +1 for the space if not the first word
             chunks = append(chunks, currentChunk.String())
             currentChunk.Reset()
+        } else if currentChunk.Len() == 0 && len(word) > maxChunkSize { // Word itself exceeds maxChunkSize, split the word
+            chunks = append(chunks, word)
+            continue
         }

         // Add a space before the word if it's not the beginning of a new chunk.
@@ -147,7 +152,7 @@ func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
         currentChunk.WriteString(word)
     }

-    // Add the last chunk if it's not empty.
+    // After the loop, add any remaining content in currentChunk to chunks.
     if currentChunk.Len() > 0 {
         chunks = append(chunks, currentChunk.String())
     }
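For reference, a standalone sketch of the chunking behaviour after this change. The function body is reconstructed from the hunks above so the example compiles on its own (the space-insertion step is not fully shown in the diff and is inferred); the sample input is made up:

package main

import (
    "fmt"
    "strings"
)

func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
    if len(paragraph) <= maxChunkSize {
        return []string{paragraph}
    }

    var chunks []string
    var currentChunk strings.Builder
    words := strings.Fields(paragraph)

    for _, word := range words {
        // Flush the current chunk if the word (plus a joining space) would overflow it.
        if currentChunk.Len() > 0 && currentChunk.Len()+len(word)+1 > maxChunkSize {
            chunks = append(chunks, currentChunk.String())
            currentChunk.Reset()
        } else if currentChunk.Len() == 0 && len(word) > maxChunkSize {
            // An oversized word is emitted as its own chunk.
            chunks = append(chunks, word)
            continue
        }
        if currentChunk.Len() > 0 {
            currentChunk.WriteString(" ")
        }
        currentChunk.WriteString(word)
    }
    if currentChunk.Len() > 0 {
        chunks = append(chunks, currentChunk.String())
    }
    return chunks
}

func main() {
    text := "retrieval augmented generation splits documents into small chunks before embedding"
    for i, c := range splitParagraphIntoChunks(text, 24) {
        fmt.Printf("chunk %d (%d bytes): %q\n", i, len(c), c)
    }
}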

go.mod
View File

@@ -35,6 +35,7 @@ require (
     github.com/mattn/go-runewidth v0.0.15 // indirect
     github.com/olekukonko/tablewriter v0.0.5 // indirect
     github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 // indirect
+    github.com/philippgille/chromem-go v0.5.0 // indirect
     github.com/rivo/uniseg v0.2.0 // indirect
     github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
     github.com/stretchr/testify v1.9.0 // indirect

go.sum
View File

@@ -59,6 +59,8 @@ github.com/onsi/gomega v1.31.1 h1:KYppCUK+bUgAZwHOu7EXVBKyQA6ILvOESHkn/tgoqvo=
 github.com/onsi/gomega v1.31.1/go.mod h1:y40C95dwAD1Nz36SsEnxvfFe8FFfNxzI5eJ0EYGyAy0=
 github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 h1:2vmb32OdDhjZf2ETGDlr9n8RYXx7c+jXPxMiPbwnA+8=
 github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4/go.mod h1:2JQx4jDHmWrbABvpOayg/+OTU6ehN0IyK2EHzceXpJo=
+github.com/philippgille/chromem-go v0.5.0 h1:bryX0F3N6jnN/21iBd8i2/k9EzPTZn3nyiqAti19si8=
+github.com/philippgille/chromem-go v0.5.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=

llm/rag/rag_chromem.go (new file)
View File

@@ -0,0 +1,88 @@
+package rag
+
+import (
+    "context"
+    "fmt"
+    "runtime"
+
+    "github.com/philippgille/chromem-go"
+    "github.com/sashabaranov/go-openai"
+)
+
+type ChromemDB struct {
+    collectionName string
+    collection     *chromem.Collection
+    index          int
+}
+
+func NewChromemDB(collection, path string, openaiClient *openai.Client) (*ChromemDB, error) {
+    // db, err := chromem.NewPersistentDB(path, true)
+    // if err != nil {
+    //     return nil, err
+    // }
+    db := chromem.NewDB()
+
+    embeddingFunc := chromem.EmbeddingFunc(
+        func(ctx context.Context, text string) ([]float32, error) {
+            fmt.Println("Creating embeddings")
+            resp, err := openaiClient.CreateEmbeddings(ctx,
+                openai.EmbeddingRequestStrings{
+                    Input: []string{text},
+                    Model: openai.AdaEmbeddingV2,
+                },
+            )
+            if err != nil {
+                return []float32{}, fmt.Errorf("error getting keys: %v", err)
+            }
+
+            if len(resp.Data) == 0 {
+                return []float32{}, fmt.Errorf("no response from OpenAI API")
+            }
+
+            embedding := resp.Data[0].Embedding
+            return embedding, nil
+        },
+    )
+
+    c, err := db.GetOrCreateCollection(collection, nil, embeddingFunc)
+    if err != nil {
+        return nil, err
+    }
+
+    return &ChromemDB{
+        collectionName: collection,
+        collection:     c,
+        index:          1,
+    }, nil
+}
+
+func (c *ChromemDB) Store(s string) error {
+    defer func() {
+        c.index++
+    }()
+
+    if s == "" {
+        return fmt.Errorf("empty string")
+    }
+
+    fmt.Println("Trying to store", s)
+
+    return c.collection.AddDocuments(context.Background(), []chromem.Document{
+        {
+            Content: s,
+            ID:      fmt.Sprint(c.index),
+        },
+    }, runtime.NumCPU())
+}
+
+func (c *ChromemDB) Search(s string, similarEntries int) ([]string, error) {
+    res, err := c.collection.Query(context.Background(), s, similarEntries, nil, nil)
+    if err != nil {
+        return nil, err
+    }
+
+    var results []string
+    for _, r := range res {
+        results = append(results, r.Content)
+    }
+
+    return results, nil
+}
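A minimal usage sketch for the new chromem-go backend (endpoint, key, and collection name are placeholders; note that the persistent-DB call is commented out above, so this store currently lives in memory only):

package main

import (
    "fmt"

    "github.com/mudler/local-agent-framework/llm"
    "github.com/mudler/local-agent-framework/llm/rag"
)

func main() {
    // llm.NewClient(apiKey, apiURL) returns the *openai.Client that NewChromemDB expects.
    client := llm.NewClient("api-key", "http://localhost:8080/v1")

    kb, err := rag.NewChromemDB("local-agent-framework", "./state", client)
    if err != nil {
        panic(err)
    }

    if err := kb.Store("Agents can ground replies in a knowledge base."); err != nil {
        panic(err)
    }

    similar, err := kb.Search("knowledge base", 1)
    if err != nil {
        panic(err)
    }
    fmt.Println(similar)
}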

View File

@@ -1,4 +1,4 @@
-package llm
+package rag

 import (
     "context"
@@ -7,8 +7,20 @@ import (
     "github.com/sashabaranov/go-openai"
 )

-func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Client, s string) error {
-    resp, err := openaiClient.CreateEmbeddings(context.TODO(),
+type LocalAIRAGDB struct {
+    client       *StoreClient
+    openaiClient *openai.Client
+}
+
+func NewLocalAIRAGDB(storeClient *StoreClient, openaiClient *openai.Client) *LocalAIRAGDB {
+    return &LocalAIRAGDB{
+        client:       storeClient,
+        openaiClient: openaiClient,
+    }
+}
+
+func (db *LocalAIRAGDB) Store(s string) error {
+    resp, err := db.openaiClient.CreateEmbeddings(context.TODO(),
         openai.EmbeddingRequestStrings{
             Input: []string{s},
             Model: openai.AdaEmbeddingV2,
@@ -28,7 +40,7 @@ func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Cl
         Keys:   [][]float32{embedding},
         Values: []string{s},
     }
-    err = client.Set(setReq)
+    err = db.client.Set(setReq)
     if err != nil {
         return fmt.Errorf("error setting keys: %v", err)
     }
@@ -36,9 +48,8 @@ func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Cl
     return nil
 }

-func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) {
-
-    resp, err := openaiClient.CreateEmbeddings(context.TODO(),
+func (db *LocalAIRAGDB) Search(s string, similarEntries int) ([]string, error) {
+    resp, err := db.openaiClient.CreateEmbeddings(context.TODO(),
         openai.EmbeddingRequestStrings{
             Input: []string{s},
             Model: openai.AdaEmbeddingV2,
@@ -58,7 +69,7 @@ func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s stri
         TopK: similarEntries, // Number of similar entries you want to find
         Key:  embedding,      // The key you're looking for similarities to
     }
-    findResp, err := client.Find(findReq)
+    findResp, err := db.client.Find(findReq)
     if err != nil {
         return []string{}, fmt.Errorf("error finding keys: %v", err)
     }

View File

@@ -1,4 +1,4 @@
-package llm
+package rag

 import (
     "bytes"