diff --git a/example/webui/main.go b/example/webui/main.go index e1df6f2..6e73b85 100644 --- a/example/webui/main.go +++ b/example/webui/main.go @@ -23,6 +23,7 @@ var apiKey = os.Getenv("API_KEY") var vectorStore = os.Getenv("VECTOR_STORE") var kbdisableIndexing = os.Getenv("KBDISABLEINDEX") var timeout = os.Getenv("TIMEOUT") +var embeddingModel = os.Getenv("EMBEDDING_MODEL") const defaultChunkSize = 4098 @@ -63,7 +64,7 @@ func main() { dbStore = rag.NewLocalAIRAGDB(laiStore, lai) default: var err error - dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai) + dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai, embeddingModel) if err != nil { panic(err) } diff --git a/llm/rag/chromem.go b/llm/rag/chromem.go index 617e769..c109a01 100644 --- a/llm/rag/chromem.go +++ b/llm/rag/chromem.go @@ -10,14 +10,15 @@ import ( ) type ChromemDB struct { - collectionName string - collection *chromem.Collection - index int - client *openai.Client - db *chromem.DB + collectionName string + collection *chromem.Collection + index int + client *openai.Client + db *chromem.DB + embeddingsModel string } -func NewChromemDB(collection, path string, openaiClient *openai.Client) (*ChromemDB, error) { +func NewChromemDB(collection, path string, openaiClient *openai.Client, embeddingsModel string) (*ChromemDB, error) { // db, err := chromem.NewPersistentDB(path, true) // if err != nil { // return nil, err @@ -25,10 +26,11 @@ func NewChromemDB(collection, path string, openaiClient *openai.Client) (*Chrome db := chromem.NewDB() chromem := &ChromemDB{ - collectionName: collection, - index: 1, - db: db, - client: openaiClient, + collectionName: collection, + index: 1, + db: db, + client: openaiClient, + embeddingsModel: embeddingsModel, } c, err := db.GetOrCreateCollection(collection, nil, chromem.embedding()) @@ -59,7 +61,7 @@ func (c *ChromemDB) embedding() chromem.EmbeddingFunc { resp, err := c.client.CreateEmbeddings(ctx, openai.EmbeddingRequestStrings{ Input: []string{text}, - Model: openai.AdaEmbeddingV2, + Model: openai.EmbeddingModel(c.embeddingsModel), }, ) if err != nil {