Customize embedding model

This commit is contained in:
mudler
2024-12-17 19:53:30 +01:00
parent 2561f2f175
commit 6bafb48cec
2 changed files with 15 additions and 12 deletions

View File

@@ -23,6 +23,7 @@ var apiKey = os.Getenv("API_KEY")
var vectorStore = os.Getenv("VECTOR_STORE") var vectorStore = os.Getenv("VECTOR_STORE")
var kbdisableIndexing = os.Getenv("KBDISABLEINDEX") var kbdisableIndexing = os.Getenv("KBDISABLEINDEX")
var timeout = os.Getenv("TIMEOUT") var timeout = os.Getenv("TIMEOUT")
var embeddingModel = os.Getenv("EMBEDDING_MODEL")
const defaultChunkSize = 4098 const defaultChunkSize = 4098
@@ -63,7 +64,7 @@ func main() {
dbStore = rag.NewLocalAIRAGDB(laiStore, lai) dbStore = rag.NewLocalAIRAGDB(laiStore, lai)
default: default:
var err error var err error
dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai) dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai, embeddingModel)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@@ -15,9 +15,10 @@ type ChromemDB struct {
index int index int
client *openai.Client client *openai.Client
db *chromem.DB db *chromem.DB
embeddingsModel string
} }
func NewChromemDB(collection, path string, openaiClient *openai.Client) (*ChromemDB, error) { func NewChromemDB(collection, path string, openaiClient *openai.Client, embeddingsModel string) (*ChromemDB, error) {
// db, err := chromem.NewPersistentDB(path, true) // db, err := chromem.NewPersistentDB(path, true)
// if err != nil { // if err != nil {
// return nil, err // return nil, err
@@ -29,6 +30,7 @@ func NewChromemDB(collection, path string, openaiClient *openai.Client) (*Chrome
index: 1, index: 1,
db: db, db: db,
client: openaiClient, client: openaiClient,
embeddingsModel: embeddingsModel,
} }
c, err := db.GetOrCreateCollection(collection, nil, chromem.embedding()) c, err := db.GetOrCreateCollection(collection, nil, chromem.embedding())
@@ -59,7 +61,7 @@ func (c *ChromemDB) embedding() chromem.EmbeddingFunc {
resp, err := c.client.CreateEmbeddings(ctx, resp, err := c.client.CreateEmbeddings(ctx,
openai.EmbeddingRequestStrings{ openai.EmbeddingRequestStrings{
Input: []string{text}, Input: []string{text},
Model: openai.AdaEmbeddingV2, Model: openai.EmbeddingModel(c.embeddingsModel),
}, },
) )
if err != nil { if err != nil {