Make Knowledgebase RAG functional (almost)
@@ -48,6 +48,7 @@ type AgentPool struct {
 	agents   map[string]*Agent
 	managers map[string]Manager
 	apiURL, model string
+	ragDB    RAGDB
 }

 type AgentPoolData map[string]AgentConfig
@@ -63,7 +64,7 @@ func loadPoolFromFile(path string) (*AgentPoolData, error) {
 	return poolData, err
 }

-func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) {
+func NewAgentPool(model, apiURL, directory string, RagDB RAGDB) (*AgentPool, error) {
 	// if file exists, try to load an existing pool.
 	// if file does not exist, create a new pool.

@@ -76,6 +77,7 @@ func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) {
 		pooldir:  directory,
 		apiURL:   apiURL,
 		model:    model,
+		ragDB:    RagDB,
 		agents:   make(map[string]*Agent),
 		pool:     make(map[string]AgentConfig),
 		managers: make(map[string]Manager),
@@ -90,6 +92,7 @@ func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) {
 		file:     poolfile,
 		apiURL:   apiURL,
 		pooldir:  directory,
+		ragDB:    RagDB,
 		model:    model,
 		agents:   make(map[string]*Agent),
 		managers: make(map[string]Manager),
@@ -229,6 +232,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error
 		),
 		WithStateFile(stateFile),
 		WithCharacterFile(characterFile),
+		WithRAGDB(a.ragDB),
 		WithAgentReasoningCallback(func(state ActionCurrentState) bool {
 			fmt.Println("Reasoning", state.Reasoning)
 			manager.Send(
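For context: the pool now carries a RAGDB dependency and hands it to every agent via WithRAGDB. The RAGDB interface itself lives in the agent package and is not shown in this diff; from the calls visible here and in SaveToStore further down, it needs at least a Store(string) error method, and presumably something to query entries back. A minimal, self-contained stand-in of that shape, useful for wiring experiments or tests (the interface and the Search method are assumptions, not the framework's actual API):

package main

import "fmt"

// ragStore mirrors the shape this commit relies on: Store is what
// SaveToStore calls; Search is a guess at how agents read entries back.
type ragStore interface {
	Store(entry string) error
	Search(query string, topK int) ([]string, error) // assumed, not in the diff
}

// memoryRAG is a trivial in-memory implementation, handy for tests.
type memoryRAG struct {
	entries []string
}

func (m *memoryRAG) Store(entry string) error {
	m.entries = append(m.entries, entry)
	return nil
}

func (m *memoryRAG) Search(query string, topK int) ([]string, error) {
	// No real similarity search here; just return the most recent entries.
	if topK > len(m.entries) {
		topK = len(m.entries)
	}
	return m.entries[len(m.entries)-topK:], nil
}

func main() {
	var db ragStore = &memoryRAG{}
	_ = db.Store("example chunk")
	res, _ := db.Search("example", 1)
	fmt.Println(res)
}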
@@ -12,6 +12,8 @@ import (
 	fiber "github.com/gofiber/fiber/v2"

 	. "github.com/mudler/local-agent-framework/agent"
+	"github.com/mudler/local-agent-framework/llm"
+	"github.com/mudler/local-agent-framework/llm/rag"
 )

 type (
@@ -24,6 +26,9 @@ type (
 var testModel = os.Getenv("TEST_MODEL")
 var apiURL = os.Getenv("API_URL")
 var apiKey = os.Getenv("API_KEY")
+var vectorStore = os.Getenv("VECTOR_STORE")
+
+const defaultChunkSize = 4098

 func init() {
 	if testModel == "" {
@@ -50,12 +55,27 @@ func main() {
 	stateDir := cwd + "/pool"
 	os.MkdirAll(stateDir, 0755)

-	pool, err := NewAgentPool(testModel, apiURL, stateDir)
+	var dbStore RAGDB
+	lai := llm.NewClient(apiKey, apiURL+"/v1")
+
+	switch vectorStore {
+	case "localai":
+		laiStore := rag.NewStoreClient(apiURL, apiKey)
+		dbStore = rag.NewLocalAIRAGDB(laiStore, lai)
+	default:
+		var err error
+		dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai)
+		if err != nil {
+			panic(err)
+		}
+	}
+
+	pool, err := NewAgentPool(testModel, apiURL, stateDir, dbStore)
 	if err != nil {
 		panic(err)
 	}

-	db, err := NewInMemoryDB(stateDir)
+	db, err := NewInMemoryDB(stateDir, dbStore)
 	if err != nil {
 		panic(err)
 	}
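The store selection in main() reads as a small factory: VECTOR_STORE=localai backs the knowledge base with LocalAI's stores API, anything else falls back to a chromem-go database kept under the state directory, with the same OpenAI-compatible client (lai) passed in, presumably for embeddings. Extracted as a helper purely for illustration (the constructor calls are exactly the ones shown above; the helper itself is not part of the commit):

// newRAGDB wraps the backend selection from main() above.
// Illustrative refactoring only; it lives in the same file/package.
func newRAGDB(vectorStore, apiKey, apiURL, stateDir string) (RAGDB, error) {
	lai := llm.NewClient(apiKey, apiURL+"/v1")

	switch vectorStore {
	case "localai":
		// Remote backend: LocalAI's stores API.
		laiStore := rag.NewStoreClient(apiURL, apiKey)
		return rag.NewLocalAIRAGDB(laiStore, lai), nil
	default:
		// Local backend: chromem-go persisted under stateDir.
		db, err := rag.NewChromemDB("local-agent-framework", stateDir, lai)
		if err != nil {
			return nil, err
		}
		return db, nil
	}
}

With that helper, main() would reduce to a single dbStore, err := newRAGDB(vectorStore, apiKey, apiURL, stateDir) call before constructing the pool and the in-memory database.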
@@ -144,6 +164,10 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error {
 		if website == "" {
 			return fmt.Errorf("please enter a URL")
 		}
+		chunkSize := defaultChunkSize
+		if payload.ChunkSize > 0 {
+			chunkSize = payload.ChunkSize
+		}

 		go func() {
 			content, err := Sitemap(website)
@@ -151,9 +175,10 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error {
 				fmt.Println("Error walking sitemap for website", err)
 			}
 			fmt.Println("Found pages: ", len(content))
+			fmt.Println("ChunkSize: ", chunkSize)

 			for _, c := range content {
-				chunks := splitParagraphIntoChunks(c, payload.ChunkSize)
+				chunks := splitParagraphIntoChunks(c, chunkSize)
 				fmt.Println("chunks: ", len(chunks))
 				for _, chunk := range chunks {
 					db.AddEntry(chunk)
@@ -162,7 +187,7 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error {
 				db.SaveDB()
 			}

-			if err := db.SaveToStore(apiKey, apiURL); err != nil {
+			if err := db.SaveToStore(); err != nil {
 				fmt.Println("Error storing in the KB", err)
 			}
 		}()
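Read together, the handler's background job now does: walk the sitemap, chunk every page with the resolved chunkSize, keep the raw chunks in the in-memory database, then push them into the RAG store through the new argument-less SaveToStore. The same flow, condensed into one helper so it can be read in one place (illustrative only; the committed code keeps this inline in the goroutine above, and all identifiers used here are the ones appearing in this diff):

// ingest mirrors the goroutine body above: fetch pages from the sitemap,
// split them into chunks, persist them, then hand everything to the RAG store.
func ingest(db *InMemoryDatabase, website string, chunkSize int) {
	content, err := Sitemap(website)
	if err != nil {
		fmt.Println("Error walking sitemap for website", err)
	}
	fmt.Println("Found pages: ", len(content))

	for _, c := range content {
		for _, chunk := range splitParagraphIntoChunks(c, chunkSize) {
			db.AddEntry(chunk)
		}
		db.SaveDB()
	}

	if err := db.SaveToStore(); err != nil {
		fmt.Println("Error storing in the KB", err)
	}
}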
@@ -10,9 +10,9 @@ import (
 	"strings"
 	"sync"

+	. "github.com/mudler/local-agent-framework/agent"
 	"jaytaylor.com/html2text"

-	"github.com/mudler/local-agent-framework/llm"
 	sitemap "github.com/oxffaa/gopher-parse-sitemap"
 )

@@ -20,6 +20,7 @@ type InMemoryDatabase struct {
 	sync.Mutex
 	Database []string
 	path     string
+	rag      RAGDB
 }

 func loadDB(path string) ([]string, error) {
@@ -33,7 +34,7 @@ func loadDB(path string) ([]string, error) {
 	return poolData, err
 }

-func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) {
+func NewInMemoryDB(knowledgebase string, store RAGDB) (*InMemoryDatabase, error) {
 	// if file exists, try to load an existing pool.
 	// if file does not exist, create a new pool.

@@ -44,6 +45,7 @@ func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) {
 		return &InMemoryDatabase{
 			Database: []string{},
 			path:     poolfile,
+			rag:      store,
 		}, nil
 	}

@@ -54,15 +56,17 @@ func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) {
 	return &InMemoryDatabase{
 		Database: poolData,
 		path:     poolfile,
+		rag:      store,
 	}, nil
 }

-func (db *InMemoryDatabase) SaveToStore(apiKey string, apiURL string) error {
+func (db *InMemoryDatabase) SaveToStore() error {
 	for _, d := range db.Database {
-		lai := llm.NewClient(apiKey, apiURL+"/v1")
-		laiStore := llm.NewStoreClient(apiURL, apiKey)
-
-		err := llm.StoreStringEmbeddingInVectorDB(laiStore, lai, d)
+		if d == "" {
+			// skip empty chunks
+			continue
+		}
+		err := db.rag.Store(d)
 		if err != nil {
 			return fmt.Errorf("Error storing in the KB: %w", err)
 		}
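SaveToStore no longer constructs LocalAI clients itself; it walks the in-memory entries, skips empty chunks, and delegates to whichever RAGDB was injected at construction time. That makes the ingestion path easy to exercise against a fake store. A minimal test-style sketch (the recording store and its Store signature are inferred from the calls in this diff; the real RAGDB interface may require more methods):

// recordingStore captures whatever SaveToStore sends to the RAG backend.
type recordingStore struct {
	stored []string
}

func (r *recordingStore) Store(s string) error {
	r.stored = append(r.stored, s)
	return nil
}

// Example usage with the new constructor:
//
//	store := &recordingStore{}
//	db, err := NewInMemoryDB(stateDir, store)
//	// handle err ...
//	db.AddEntry("some chunk")
//	db.AddEntry("") // empty entries are skipped by SaveToStore
//	err = db.SaveToStore()
//	// store.stored now contains only "some chunk"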
@@ -119,8 +123,6 @@ func Sitemap(url string) (res []string, err error) {
 // and returns a slice of strings where each string is a chunk of the paragraph
 // that is at most maxChunkSize long, ensuring that words are not split.
 func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
-	// Check if the paragraph length is less than or equal to maxChunkSize.
-	// If so, return the paragraph as the only chunk.
 	if len(paragraph) <= maxChunkSize {
 		return []string{paragraph}
 	}
@@ -131,11 +133,14 @@ func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
 	words := strings.Fields(paragraph) // Splits the paragraph into words.

 	for _, word := range words {
-		// Check if adding the next word would exceed the maxChunkSize.
-		// If so, add the currentChunk to the chunks slice and start a new chunk.
-		if currentChunk.Len()+len(word) > maxChunkSize {
+		// If adding the next word would exceed maxChunkSize (considering a space if not the first word in a chunk),
+		// add the currentChunk to chunks, and reset currentChunk.
+		if currentChunk.Len() > 0 && currentChunk.Len()+len(word)+1 > maxChunkSize { // +1 for the space if not the first word
 			chunks = append(chunks, currentChunk.String())
 			currentChunk.Reset()
+		} else if currentChunk.Len() == 0 && len(word) > maxChunkSize { // Word itself exceeds maxChunkSize, split the word
+			chunks = append(chunks, word)
+			continue
 		}

 		// Add a space before the word if it's not the beginning of a new chunk.
@@ -147,7 +152,7 @@ func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
 		currentChunk.WriteString(word)
 	}

-	// Add the last chunk if it's not empty.
+	// After the loop, add any remaining content in currentChunk to chunks.
 	if currentChunk.Len() > 0 {
 		chunks = append(chunks, currentChunk.String())
 	}
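The reworked chunker now only flushes a non-empty builder (the old version could emit empty chunks), accounts for the joining space when deciding whether a word still fits, and passes a word longer than maxChunkSize through as its own chunk. A self-contained sketch with a small driver; the variable declarations and the space-joining lines fall between the hunks shown above and are reconstructed here, so treat this as an approximation of the committed function rather than an exact copy:

// chunkdemo.go: standalone sketch of the chunking logic in this commit.
package main

import (
	"fmt"
	"strings"
)

func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
	if len(paragraph) <= maxChunkSize {
		return []string{paragraph}
	}

	var chunks []string
	var currentChunk strings.Builder

	words := strings.Fields(paragraph)

	for _, word := range words {
		// Flush the current chunk when the next word (plus a joining space)
		// would push it past maxChunkSize.
		if currentChunk.Len() > 0 && currentChunk.Len()+len(word)+1 > maxChunkSize {
			chunks = append(chunks, currentChunk.String())
			currentChunk.Reset()
		} else if currentChunk.Len() == 0 && len(word) > maxChunkSize {
			// A single word longer than maxChunkSize becomes its own chunk.
			chunks = append(chunks, word)
			continue
		}

		if currentChunk.Len() > 0 {
			currentChunk.WriteString(" ")
		}
		currentChunk.WriteString(word)
	}

	if currentChunk.Len() > 0 {
		chunks = append(chunks, currentChunk.String())
	}

	return chunks
}

func main() {
	text := "alpha beta gamma delta epsilon zeta eta theta"
	for i, c := range splitParagraphIntoChunks(text, 16) {
		fmt.Printf("%d: %q\n", i, c)
	}
	// Prints three chunks, each at most 16 bytes:
	// 0: "alpha beta gamma"
	// 1: "delta epsilon"
	// 2: "zeta eta theta"
}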