Make Knowledgebase RAG functional (almost)
This commit is contained in:
88
llm/rag/rag_chromem.go
Normal file
88
llm/rag/rag_chromem.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/philippgille/chromem-go"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type ChromemDB struct {
|
||||
collectionName string
|
||||
collection *chromem.Collection
|
||||
index int
|
||||
}
|
||||
|
||||
func NewChromemDB(collection, path string, openaiClient *openai.Client) (*ChromemDB, error) {
|
||||
// db, err := chromem.NewPersistentDB(path, true)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
db := chromem.NewDB()
|
||||
|
||||
embeddingFunc := chromem.EmbeddingFunc(
|
||||
func(ctx context.Context, text string) ([]float32, error) {
|
||||
fmt.Println("Creating embeddings")
|
||||
resp, err := openaiClient.CreateEmbeddings(ctx,
|
||||
openai.EmbeddingRequestStrings{
|
||||
Input: []string{text},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return []float32{}, fmt.Errorf("error getting keys: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
return []float32{}, fmt.Errorf("no response from OpenAI API")
|
||||
}
|
||||
|
||||
embedding := resp.Data[0].Embedding
|
||||
|
||||
return embedding, nil
|
||||
},
|
||||
)
|
||||
|
||||
c, err := db.GetOrCreateCollection(collection, nil, embeddingFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ChromemDB{
|
||||
collectionName: collection,
|
||||
collection: c,
|
||||
index: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *ChromemDB) Store(s string) error {
|
||||
defer func() {
|
||||
c.index++
|
||||
}()
|
||||
if s == "" {
|
||||
return fmt.Errorf("empty string")
|
||||
}
|
||||
fmt.Println("Trying to store", s)
|
||||
return c.collection.AddDocuments(context.Background(), []chromem.Document{
|
||||
{
|
||||
Content: s,
|
||||
ID: fmt.Sprint(c.index),
|
||||
},
|
||||
}, runtime.NumCPU())
|
||||
}
|
||||
|
||||
func (c *ChromemDB) Search(s string, similarEntries int) ([]string, error) {
|
||||
res, err := c.collection.Query(context.Background(), s, similarEntries, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var results []string
|
||||
for _, r := range res {
|
||||
results = append(results, r.Content)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package llm
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -7,8 +7,20 @@ import (
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Client, s string) error {
|
||||
resp, err := openaiClient.CreateEmbeddings(context.TODO(),
|
||||
type LocalAIRAGDB struct {
|
||||
client *StoreClient
|
||||
openaiClient *openai.Client
|
||||
}
|
||||
|
||||
func NewLocalAIRAGDB(storeClient *StoreClient, openaiClient *openai.Client) *LocalAIRAGDB {
|
||||
return &LocalAIRAGDB{
|
||||
client: storeClient,
|
||||
openaiClient: openaiClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (db *LocalAIRAGDB) Store(s string) error {
|
||||
resp, err := db.openaiClient.CreateEmbeddings(context.TODO(),
|
||||
openai.EmbeddingRequestStrings{
|
||||
Input: []string{s},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
@@ -28,7 +40,7 @@ func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Cl
|
||||
Keys: [][]float32{embedding},
|
||||
Values: []string{s},
|
||||
}
|
||||
err = client.Set(setReq)
|
||||
err = db.client.Set(setReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting keys: %v", err)
|
||||
}
|
||||
@@ -36,9 +48,8 @@ func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Cl
|
||||
return nil
|
||||
}
|
||||
|
||||
func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) {
|
||||
|
||||
resp, err := openaiClient.CreateEmbeddings(context.TODO(),
|
||||
func (db *LocalAIRAGDB) Search(s string, similarEntries int) ([]string, error) {
|
||||
resp, err := db.openaiClient.CreateEmbeddings(context.TODO(),
|
||||
openai.EmbeddingRequestStrings{
|
||||
Input: []string{s},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
@@ -58,7 +69,7 @@ func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s stri
|
||||
TopK: similarEntries, // Number of similar entries you want to find
|
||||
Key: embedding, // The key you're looking for similarities to
|
||||
}
|
||||
findResp, err := client.Find(findReq)
|
||||
findResp, err := db.client.Find(findReq)
|
||||
if err != nil {
|
||||
return []string{}, fmt.Errorf("error finding keys: %v", err)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package llm
|
||||
package rag
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
Reference in New Issue
Block a user