Customize embedding model

This commit is contained in:
mudler
2024-12-17 19:53:30 +01:00
parent 2561f2f175
commit 6bafb48cec
2 changed files with 15 additions and 12 deletions

View File

@@ -23,6 +23,7 @@ var apiKey = os.Getenv("API_KEY")
var vectorStore = os.Getenv("VECTOR_STORE") var vectorStore = os.Getenv("VECTOR_STORE")
var kbdisableIndexing = os.Getenv("KBDISABLEINDEX") var kbdisableIndexing = os.Getenv("KBDISABLEINDEX")
var timeout = os.Getenv("TIMEOUT") var timeout = os.Getenv("TIMEOUT")
var embeddingModel = os.Getenv("EMBEDDING_MODEL")
const defaultChunkSize = 4098 const defaultChunkSize = 4098
@@ -63,7 +64,7 @@ func main() {
dbStore = rag.NewLocalAIRAGDB(laiStore, lai) dbStore = rag.NewLocalAIRAGDB(laiStore, lai)
default: default:
var err error var err error
dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai) dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai, embeddingModel)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@@ -10,14 +10,15 @@ import (
) )
type ChromemDB struct { type ChromemDB struct {
collectionName string collectionName string
collection *chromem.Collection collection *chromem.Collection
index int index int
client *openai.Client client *openai.Client
db *chromem.DB db *chromem.DB
embeddingsModel string
} }
func NewChromemDB(collection, path string, openaiClient *openai.Client) (*ChromemDB, error) { func NewChromemDB(collection, path string, openaiClient *openai.Client, embeddingsModel string) (*ChromemDB, error) {
// db, err := chromem.NewPersistentDB(path, true) // db, err := chromem.NewPersistentDB(path, true)
// if err != nil { // if err != nil {
// return nil, err // return nil, err
@@ -25,10 +26,11 @@ func NewChromemDB(collection, path string, openaiClient *openai.Client) (*Chrome
db := chromem.NewDB() db := chromem.NewDB()
chromem := &ChromemDB{ chromem := &ChromemDB{
collectionName: collection, collectionName: collection,
index: 1, index: 1,
db: db, db: db,
client: openaiClient, client: openaiClient,
embeddingsModel: embeddingsModel,
} }
c, err := db.GetOrCreateCollection(collection, nil, chromem.embedding()) c, err := db.GetOrCreateCollection(collection, nil, chromem.embedding())
@@ -59,7 +61,7 @@ func (c *ChromemDB) embedding() chromem.EmbeddingFunc {
resp, err := c.client.CreateEmbeddings(ctx, resp, err := c.client.CreateEmbeddings(ctx,
openai.EmbeddingRequestStrings{ openai.EmbeddingRequestStrings{
Input: []string{text}, Input: []string{text},
Model: openai.AdaEmbeddingV2, Model: openai.EmbeddingModel(c.embeddingsModel),
}, },
) )
if err != nil { if err != nil {