Make Knowledgebase RAG functional (almost)
This commit is contained in:
@@ -23,7 +23,6 @@ type Agent struct {
|
|||||||
options *options
|
options *options
|
||||||
Character Character
|
Character Character
|
||||||
client *openai.Client
|
client *openai.Client
|
||||||
storeClient *llm.StoreClient
|
|
||||||
jobQueue chan *Job
|
jobQueue chan *Job
|
||||||
actionContext *action.ActionContext
|
actionContext *action.ActionContext
|
||||||
context *action.ActionContext
|
context *action.ActionContext
|
||||||
@@ -37,6 +36,11 @@ type Agent struct {
|
|||||||
newConversations chan openai.ChatCompletionMessage
|
newConversations chan openai.ChatCompletionMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RAGDB is the knowledge-base backend used for retrieval-augmented generation.
// Implementations are injected through WithRAGDB and queried by the agent with
// the user's latest message.
type RAGDB interface {
	// Store adds the text s to the knowledge base.
	Store(s string) error

	// Search looks up entries similar to s. similarEntries is the number of
	// similar entries requested from the backend.
	Search(s string, similarEntries int) ([]string, error)
}
|
||||||
|
|
||||||
func New(opts ...Option) (*Agent, error) {
|
func New(opts ...Option) (*Agent, error) {
|
||||||
options, err := newOptions(opts...)
|
options, err := newOptions(opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -44,7 +48,6 @@ func New(opts ...Option) (*Agent, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL)
|
client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL)
|
||||||
storeClient := llm.NewStoreClient(options.LLMAPI.APIURL, options.LLMAPI.APIKey)
|
|
||||||
|
|
||||||
c := context.Background()
|
c := context.Background()
|
||||||
if options.context != nil {
|
if options.context != nil {
|
||||||
@@ -58,7 +61,6 @@ func New(opts ...Option) (*Agent, error) {
|
|||||||
client: client,
|
client: client,
|
||||||
Character: options.character,
|
Character: options.character,
|
||||||
currentState: &action.StateResult{},
|
currentState: &action.StateResult{},
|
||||||
storeClient: storeClient,
|
|
||||||
context: action.NewContext(ctx, cancel),
|
context: action.NewContext(ctx, cancel),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,7 +206,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RAG
|
// RAG
|
||||||
if a.options.enableKB {
|
if a.options.enableKB && a.options.ragdb != nil {
|
||||||
// Walk conversation from bottom to top, and find the first message of the user
|
// Walk conversation from bottom to top, and find the first message of the user
|
||||||
// to use it as a query to the KB
|
// to use it as a query to the KB
|
||||||
var userMessage string
|
var userMessage string
|
||||||
@@ -216,7 +218,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if userMessage != "" {
|
if userMessage != "" {
|
||||||
results, err := llm.FindSimilarStrings(a.storeClient, a.client, userMessage, a.options.kbResults)
|
results, err := a.options.ragdb.Search(userMessage, a.options.kbResults)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if a.options.debugMode {
|
if a.options.debugMode {
|
||||||
fmt.Println("Error finding similar strings inside KB:", err)
|
fmt.Println("Error finding similar strings inside KB:", err)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ type options struct {
|
|||||||
permanentGoal string
|
permanentGoal string
|
||||||
periodicRuns time.Duration
|
periodicRuns time.Duration
|
||||||
kbResults int
|
kbResults int
|
||||||
|
ragdb RAGDB
|
||||||
|
|
||||||
// callbacks
|
// callbacks
|
||||||
reasoningCallback func(ActionCurrentState) bool
|
reasoningCallback func(ActionCurrentState) bool
|
||||||
@@ -102,6 +103,13 @@ var EnablePersonality = func(o *options) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithRAGDB(db RAGDB) Option {
|
||||||
|
return func(o *options) error {
|
||||||
|
o.ragdb = db
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithLLMAPIURL(url string) Option {
|
func WithLLMAPIURL(url string) Option {
|
||||||
return func(o *options) error {
|
return func(o *options) error {
|
||||||
o.LLMAPI.APIURL = url
|
o.LLMAPI.APIURL = url
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ type AgentPool struct {
|
|||||||
agents map[string]*Agent
|
agents map[string]*Agent
|
||||||
managers map[string]Manager
|
managers map[string]Manager
|
||||||
apiURL, model string
|
apiURL, model string
|
||||||
|
ragDB RAGDB
|
||||||
}
|
}
|
||||||
|
|
||||||
type AgentPoolData map[string]AgentConfig
|
type AgentPoolData map[string]AgentConfig
|
||||||
@@ -63,7 +64,7 @@ func loadPoolFromFile(path string) (*AgentPoolData, error) {
|
|||||||
return poolData, err
|
return poolData, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) {
|
func NewAgentPool(model, apiURL, directory string, RagDB RAGDB) (*AgentPool, error) {
|
||||||
// if file exists, try to load an existing pool.
|
// if file exists, try to load an existing pool.
|
||||||
// if file does not exist, create a new pool.
|
// if file does not exist, create a new pool.
|
||||||
|
|
||||||
@@ -76,6 +77,7 @@ func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) {
|
|||||||
pooldir: directory,
|
pooldir: directory,
|
||||||
apiURL: apiURL,
|
apiURL: apiURL,
|
||||||
model: model,
|
model: model,
|
||||||
|
ragDB: RagDB,
|
||||||
agents: make(map[string]*Agent),
|
agents: make(map[string]*Agent),
|
||||||
pool: make(map[string]AgentConfig),
|
pool: make(map[string]AgentConfig),
|
||||||
managers: make(map[string]Manager),
|
managers: make(map[string]Manager),
|
||||||
@@ -90,6 +92,7 @@ func NewAgentPool(model, apiURL, directory string) (*AgentPool, error) {
|
|||||||
file: poolfile,
|
file: poolfile,
|
||||||
apiURL: apiURL,
|
apiURL: apiURL,
|
||||||
pooldir: directory,
|
pooldir: directory,
|
||||||
|
ragDB: RagDB,
|
||||||
model: model,
|
model: model,
|
||||||
agents: make(map[string]*Agent),
|
agents: make(map[string]*Agent),
|
||||||
managers: make(map[string]Manager),
|
managers: make(map[string]Manager),
|
||||||
@@ -229,6 +232,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error
|
|||||||
),
|
),
|
||||||
WithStateFile(stateFile),
|
WithStateFile(stateFile),
|
||||||
WithCharacterFile(characterFile),
|
WithCharacterFile(characterFile),
|
||||||
|
WithRAGDB(a.ragDB),
|
||||||
WithAgentReasoningCallback(func(state ActionCurrentState) bool {
|
WithAgentReasoningCallback(func(state ActionCurrentState) bool {
|
||||||
fmt.Println("Reasoning", state.Reasoning)
|
fmt.Println("Reasoning", state.Reasoning)
|
||||||
manager.Send(
|
manager.Send(
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
fiber "github.com/gofiber/fiber/v2"
|
fiber "github.com/gofiber/fiber/v2"
|
||||||
|
|
||||||
. "github.com/mudler/local-agent-framework/agent"
|
. "github.com/mudler/local-agent-framework/agent"
|
||||||
|
"github.com/mudler/local-agent-framework/llm"
|
||||||
|
"github.com/mudler/local-agent-framework/llm/rag"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@@ -24,6 +26,9 @@ type (
|
|||||||
var testModel = os.Getenv("TEST_MODEL")
|
var testModel = os.Getenv("TEST_MODEL")
|
||||||
var apiURL = os.Getenv("API_URL")
|
var apiURL = os.Getenv("API_URL")
|
||||||
var apiKey = os.Getenv("API_KEY")
|
var apiKey = os.Getenv("API_KEY")
|
||||||
|
var vectorStore = os.Getenv("VECTOR_STORE")
|
||||||
|
|
||||||
|
const defaultChunkSize = 4098
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if testModel == "" {
|
if testModel == "" {
|
||||||
@@ -50,12 +55,27 @@ func main() {
|
|||||||
stateDir := cwd + "/pool"
|
stateDir := cwd + "/pool"
|
||||||
os.MkdirAll(stateDir, 0755)
|
os.MkdirAll(stateDir, 0755)
|
||||||
|
|
||||||
pool, err := NewAgentPool(testModel, apiURL, stateDir)
|
var dbStore RAGDB
|
||||||
|
lai := llm.NewClient(apiKey, apiURL+"/v1")
|
||||||
|
|
||||||
|
switch vectorStore {
|
||||||
|
case "localai":
|
||||||
|
laiStore := rag.NewStoreClient(apiURL, apiKey)
|
||||||
|
dbStore = rag.NewLocalAIRAGDB(laiStore, lai)
|
||||||
|
default:
|
||||||
|
var err error
|
||||||
|
dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pool, err := NewAgentPool(testModel, apiURL, stateDir, dbStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := NewInMemoryDB(stateDir)
|
db, err := NewInMemoryDB(stateDir, dbStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -144,6 +164,10 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error {
|
|||||||
if website == "" {
|
if website == "" {
|
||||||
return fmt.Errorf("please enter a URL")
|
return fmt.Errorf("please enter a URL")
|
||||||
}
|
}
|
||||||
|
chunkSize := defaultChunkSize
|
||||||
|
if payload.ChunkSize > 0 {
|
||||||
|
chunkSize = payload.ChunkSize
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
content, err := Sitemap(website)
|
content, err := Sitemap(website)
|
||||||
@@ -151,9 +175,10 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error {
|
|||||||
fmt.Println("Error walking sitemap for website", err)
|
fmt.Println("Error walking sitemap for website", err)
|
||||||
}
|
}
|
||||||
fmt.Println("Found pages: ", len(content))
|
fmt.Println("Found pages: ", len(content))
|
||||||
|
fmt.Println("ChunkSize: ", chunkSize)
|
||||||
|
|
||||||
for _, c := range content {
|
for _, c := range content {
|
||||||
chunks := splitParagraphIntoChunks(c, payload.ChunkSize)
|
chunks := splitParagraphIntoChunks(c, chunkSize)
|
||||||
fmt.Println("chunks: ", len(chunks))
|
fmt.Println("chunks: ", len(chunks))
|
||||||
for _, chunk := range chunks {
|
for _, chunk := range chunks {
|
||||||
db.AddEntry(chunk)
|
db.AddEntry(chunk)
|
||||||
@@ -162,7 +187,7 @@ func (a *App) KnowledgeBase(db *InMemoryDatabase) func(c *fiber.Ctx) error {
|
|||||||
db.SaveDB()
|
db.SaveDB()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := db.SaveToStore(apiKey, apiURL); err != nil {
|
if err := db.SaveToStore(); err != nil {
|
||||||
fmt.Println("Error storing in the KB", err)
|
fmt.Println("Error storing in the KB", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
. "github.com/mudler/local-agent-framework/agent"
|
||||||
"jaytaylor.com/html2text"
|
"jaytaylor.com/html2text"
|
||||||
|
|
||||||
"github.com/mudler/local-agent-framework/llm"
|
|
||||||
sitemap "github.com/oxffaa/gopher-parse-sitemap"
|
sitemap "github.com/oxffaa/gopher-parse-sitemap"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,6 +20,7 @@ type InMemoryDatabase struct {
|
|||||||
sync.Mutex
|
sync.Mutex
|
||||||
Database []string
|
Database []string
|
||||||
path string
|
path string
|
||||||
|
rag RAGDB
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadDB(path string) ([]string, error) {
|
func loadDB(path string) ([]string, error) {
|
||||||
@@ -33,7 +34,7 @@ func loadDB(path string) ([]string, error) {
|
|||||||
return poolData, err
|
return poolData, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) {
|
func NewInMemoryDB(knowledgebase string, store RAGDB) (*InMemoryDatabase, error) {
|
||||||
// if file exists, try to load an existing pool.
|
// if file exists, try to load an existing pool.
|
||||||
// if file does not exist, create a new pool.
|
// if file does not exist, create a new pool.
|
||||||
|
|
||||||
@@ -44,6 +45,7 @@ func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) {
|
|||||||
return &InMemoryDatabase{
|
return &InMemoryDatabase{
|
||||||
Database: []string{},
|
Database: []string{},
|
||||||
path: poolfile,
|
path: poolfile,
|
||||||
|
rag: store,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,15 +56,17 @@ func NewInMemoryDB(knowledgebase string) (*InMemoryDatabase, error) {
|
|||||||
return &InMemoryDatabase{
|
return &InMemoryDatabase{
|
||||||
Database: poolData,
|
Database: poolData,
|
||||||
path: poolfile,
|
path: poolfile,
|
||||||
|
rag: store,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *InMemoryDatabase) SaveToStore(apiKey string, apiURL string) error {
|
func (db *InMemoryDatabase) SaveToStore() error {
|
||||||
for _, d := range db.Database {
|
for _, d := range db.Database {
|
||||||
lai := llm.NewClient(apiKey, apiURL+"/v1")
|
if d == "" {
|
||||||
laiStore := llm.NewStoreClient(apiURL, apiKey)
|
// skip empty chunks
|
||||||
|
continue
|
||||||
err := llm.StoreStringEmbeddingInVectorDB(laiStore, lai, d)
|
}
|
||||||
|
err := db.rag.Store(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error storing in the KB: %w", err)
|
return fmt.Errorf("Error storing in the KB: %w", err)
|
||||||
}
|
}
|
||||||
@@ -119,8 +123,6 @@ func Sitemap(url string) (res []string, err error) {
|
|||||||
// and returns a slice of strings where each string is a chunk of the paragraph
|
// and returns a slice of strings where each string is a chunk of the paragraph
|
||||||
// that is at most maxChunkSize long, ensuring that words are not split.
|
// that is at most maxChunkSize long, ensuring that words are not split.
|
||||||
func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
|
func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
|
||||||
// Check if the paragraph length is less than or equal to maxChunkSize.
|
|
||||||
// If so, return the paragraph as the only chunk.
|
|
||||||
if len(paragraph) <= maxChunkSize {
|
if len(paragraph) <= maxChunkSize {
|
||||||
return []string{paragraph}
|
return []string{paragraph}
|
||||||
}
|
}
|
||||||
@@ -131,11 +133,14 @@ func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
|
|||||||
words := strings.Fields(paragraph) // Splits the paragraph into words.
|
words := strings.Fields(paragraph) // Splits the paragraph into words.
|
||||||
|
|
||||||
for _, word := range words {
|
for _, word := range words {
|
||||||
// Check if adding the next word would exceed the maxChunkSize.
|
// If adding the next word would exceed maxChunkSize (considering a space if not the first word in a chunk),
|
||||||
// If so, add the currentChunk to the chunks slice and start a new chunk.
|
// add the currentChunk to chunks, and reset currentChunk.
|
||||||
if currentChunk.Len()+len(word) > maxChunkSize {
|
if currentChunk.Len() > 0 && currentChunk.Len()+len(word)+1 > maxChunkSize { // +1 for the space if not the first word
|
||||||
chunks = append(chunks, currentChunk.String())
|
chunks = append(chunks, currentChunk.String())
|
||||||
currentChunk.Reset()
|
currentChunk.Reset()
|
||||||
|
} else if currentChunk.Len() == 0 && len(word) > maxChunkSize { // Word itself exceeds maxChunkSize, split the word
|
||||||
|
chunks = append(chunks, word)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a space before the word if it's not the beginning of a new chunk.
|
// Add a space before the word if it's not the beginning of a new chunk.
|
||||||
@@ -147,7 +152,7 @@ func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
|
|||||||
currentChunk.WriteString(word)
|
currentChunk.WriteString(word)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the last chunk if it's not empty.
|
// After the loop, add any remaining content in currentChunk to chunks.
|
||||||
if currentChunk.Len() > 0 {
|
if currentChunk.Len() > 0 {
|
||||||
chunks = append(chunks, currentChunk.String())
|
chunks = append(chunks, currentChunk.String())
|
||||||
}
|
}
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -35,6 +35,7 @@ require (
|
|||||||
github.com/mattn/go-runewidth v0.0.15 // indirect
|
github.com/mattn/go-runewidth v0.0.15 // indirect
|
||||||
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
github.com/olekukonko/tablewriter v0.0.5 // indirect
|
||||||
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 // indirect
|
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 // indirect
|
||||||
|
github.com/philippgille/chromem-go v0.5.0 // indirect
|
||||||
github.com/rivo/uniseg v0.2.0 // indirect
|
github.com/rivo/uniseg v0.2.0 // indirect
|
||||||
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
|
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
|
||||||
github.com/stretchr/testify v1.9.0 // indirect
|
github.com/stretchr/testify v1.9.0 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -59,6 +59,8 @@ github.com/onsi/gomega v1.31.1 h1:KYppCUK+bUgAZwHOu7EXVBKyQA6ILvOESHkn/tgoqvo=
|
|||||||
github.com/onsi/gomega v1.31.1/go.mod h1:y40C95dwAD1Nz36SsEnxvfFe8FFfNxzI5eJ0EYGyAy0=
|
github.com/onsi/gomega v1.31.1/go.mod h1:y40C95dwAD1Nz36SsEnxvfFe8FFfNxzI5eJ0EYGyAy0=
|
||||||
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 h1:2vmb32OdDhjZf2ETGDlr9n8RYXx7c+jXPxMiPbwnA+8=
|
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 h1:2vmb32OdDhjZf2ETGDlr9n8RYXx7c+jXPxMiPbwnA+8=
|
||||||
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4/go.mod h1:2JQx4jDHmWrbABvpOayg/+OTU6ehN0IyK2EHzceXpJo=
|
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4/go.mod h1:2JQx4jDHmWrbABvpOayg/+OTU6ehN0IyK2EHzceXpJo=
|
||||||
|
github.com/philippgille/chromem-go v0.5.0 h1:bryX0F3N6jnN/21iBd8i2/k9EzPTZn3nyiqAti19si8=
|
||||||
|
github.com/philippgille/chromem-go v0.5.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||||
|
|||||||
88
llm/rag/rag_chromem.go
Normal file
88
llm/rag/rag_chromem.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package rag
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/philippgille/chromem-go"
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChromemDB struct {
|
||||||
|
collectionName string
|
||||||
|
collection *chromem.Collection
|
||||||
|
index int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChromemDB(collection, path string, openaiClient *openai.Client) (*ChromemDB, error) {
|
||||||
|
// db, err := chromem.NewPersistentDB(path, true)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
db := chromem.NewDB()
|
||||||
|
|
||||||
|
embeddingFunc := chromem.EmbeddingFunc(
|
||||||
|
func(ctx context.Context, text string) ([]float32, error) {
|
||||||
|
fmt.Println("Creating embeddings")
|
||||||
|
resp, err := openaiClient.CreateEmbeddings(ctx,
|
||||||
|
openai.EmbeddingRequestStrings{
|
||||||
|
Input: []string{text},
|
||||||
|
Model: openai.AdaEmbeddingV2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return []float32{}, fmt.Errorf("error getting keys: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.Data) == 0 {
|
||||||
|
return []float32{}, fmt.Errorf("no response from OpenAI API")
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding := resp.Data[0].Embedding
|
||||||
|
|
||||||
|
return embedding, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
c, err := db.GetOrCreateCollection(collection, nil, embeddingFunc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ChromemDB{
|
||||||
|
collectionName: collection,
|
||||||
|
collection: c,
|
||||||
|
index: 1,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChromemDB) Store(s string) error {
|
||||||
|
defer func() {
|
||||||
|
c.index++
|
||||||
|
}()
|
||||||
|
if s == "" {
|
||||||
|
return fmt.Errorf("empty string")
|
||||||
|
}
|
||||||
|
fmt.Println("Trying to store", s)
|
||||||
|
return c.collection.AddDocuments(context.Background(), []chromem.Document{
|
||||||
|
{
|
||||||
|
Content: s,
|
||||||
|
ID: fmt.Sprint(c.index),
|
||||||
|
},
|
||||||
|
}, runtime.NumCPU())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChromemDB) Search(s string, similarEntries int) ([]string, error) {
|
||||||
|
res, err := c.collection.Query(context.Background(), s, similarEntries, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []string
|
||||||
|
for _, r := range res {
|
||||||
|
results = append(results, r.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package llm
|
package rag
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -7,8 +7,20 @@ import (
|
|||||||
"github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Client, s string) error {
|
type LocalAIRAGDB struct {
|
||||||
resp, err := openaiClient.CreateEmbeddings(context.TODO(),
|
client *StoreClient
|
||||||
|
openaiClient *openai.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLocalAIRAGDB(storeClient *StoreClient, openaiClient *openai.Client) *LocalAIRAGDB {
|
||||||
|
return &LocalAIRAGDB{
|
||||||
|
client: storeClient,
|
||||||
|
openaiClient: openaiClient,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *LocalAIRAGDB) Store(s string) error {
|
||||||
|
resp, err := db.openaiClient.CreateEmbeddings(context.TODO(),
|
||||||
openai.EmbeddingRequestStrings{
|
openai.EmbeddingRequestStrings{
|
||||||
Input: []string{s},
|
Input: []string{s},
|
||||||
Model: openai.AdaEmbeddingV2,
|
Model: openai.AdaEmbeddingV2,
|
||||||
@@ -28,7 +40,7 @@ func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Cl
|
|||||||
Keys: [][]float32{embedding},
|
Keys: [][]float32{embedding},
|
||||||
Values: []string{s},
|
Values: []string{s},
|
||||||
}
|
}
|
||||||
err = client.Set(setReq)
|
err = db.client.Set(setReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error setting keys: %v", err)
|
return fmt.Errorf("error setting keys: %v", err)
|
||||||
}
|
}
|
||||||
@@ -36,9 +48,8 @@ func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Cl
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) {
|
func (db *LocalAIRAGDB) Search(s string, similarEntries int) ([]string, error) {
|
||||||
|
resp, err := db.openaiClient.CreateEmbeddings(context.TODO(),
|
||||||
resp, err := openaiClient.CreateEmbeddings(context.TODO(),
|
|
||||||
openai.EmbeddingRequestStrings{
|
openai.EmbeddingRequestStrings{
|
||||||
Input: []string{s},
|
Input: []string{s},
|
||||||
Model: openai.AdaEmbeddingV2,
|
Model: openai.AdaEmbeddingV2,
|
||||||
@@ -58,7 +69,7 @@ func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s stri
|
|||||||
TopK: similarEntries, // Number of similar entries you want to find
|
TopK: similarEntries, // Number of similar entries you want to find
|
||||||
Key: embedding, // The key you're looking for similarities to
|
Key: embedding, // The key you're looking for similarities to
|
||||||
}
|
}
|
||||||
findResp, err := client.Find(findReq)
|
findResp, err := db.client.Find(findReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return []string{}, fmt.Errorf("error finding keys: %v", err)
|
return []string{}, fmt.Errorf("error finding keys: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package llm
|
package rag
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
Reference in New Issue
Block a user