Customize embedding model
This commit is contained in:
@@ -23,6 +23,7 @@ var apiKey = os.Getenv("API_KEY")
|
|||||||
var vectorStore = os.Getenv("VECTOR_STORE")
|
var vectorStore = os.Getenv("VECTOR_STORE")
|
||||||
var kbdisableIndexing = os.Getenv("KBDISABLEINDEX")
|
var kbdisableIndexing = os.Getenv("KBDISABLEINDEX")
|
||||||
var timeout = os.Getenv("TIMEOUT")
|
var timeout = os.Getenv("TIMEOUT")
|
||||||
|
var embeddingModel = os.Getenv("EMBEDDING_MODEL")
|
||||||
|
|
||||||
const defaultChunkSize = 4098
|
const defaultChunkSize = 4098
|
||||||
|
|
||||||
@@ -63,7 +64,7 @@ func main() {
|
|||||||
dbStore = rag.NewLocalAIRAGDB(laiStore, lai)
|
dbStore = rag.NewLocalAIRAGDB(laiStore, lai)
|
||||||
default:
|
default:
|
||||||
var err error
|
var err error
|
||||||
dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai)
|
dbStore, err = rag.NewChromemDB("local-agent-framework", stateDir, lai, embeddingModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,9 +15,10 @@ type ChromemDB struct {
|
|||||||
index int
|
index int
|
||||||
client *openai.Client
|
client *openai.Client
|
||||||
db *chromem.DB
|
db *chromem.DB
|
||||||
|
embeddingsModel string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChromemDB(collection, path string, openaiClient *openai.Client) (*ChromemDB, error) {
|
func NewChromemDB(collection, path string, openaiClient *openai.Client, embeddingsModel string) (*ChromemDB, error) {
|
||||||
// db, err := chromem.NewPersistentDB(path, true)
|
// db, err := chromem.NewPersistentDB(path, true)
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// return nil, err
|
// return nil, err
|
||||||
@@ -29,6 +30,7 @@ func NewChromemDB(collection, path string, openaiClient *openai.Client) (*Chrome
|
|||||||
index: 1,
|
index: 1,
|
||||||
db: db,
|
db: db,
|
||||||
client: openaiClient,
|
client: openaiClient,
|
||||||
|
embeddingsModel: embeddingsModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := db.GetOrCreateCollection(collection, nil, chromem.embedding())
|
c, err := db.GetOrCreateCollection(collection, nil, chromem.embedding())
|
||||||
@@ -59,7 +61,7 @@ func (c *ChromemDB) embedding() chromem.EmbeddingFunc {
|
|||||||
resp, err := c.client.CreateEmbeddings(ctx,
|
resp, err := c.client.CreateEmbeddings(ctx,
|
||||||
openai.EmbeddingRequestStrings{
|
openai.EmbeddingRequestStrings{
|
||||||
Input: []string{text},
|
Input: []string{text},
|
||||||
Model: openai.AdaEmbeddingV2,
|
Model: openai.EmbeddingModel(c.embeddingsModel),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user