rag: add KB to conversation

Ettore Di Giacinto
2024-04-09 22:34:22 +02:00
parent 36abf837a9
commit 78ba7871e9
4 changed files with 87 additions and 27 deletions

View File

@@ -23,6 +23,7 @@ type Agent struct {
options *options
Character Character
client *openai.Client
storeClient *llm.StoreClient
jobQueue chan *Job
actionContext *action.ActionContext
context *action.ActionContext
@@ -43,6 +44,7 @@ func New(opts ...Option) (*Agent, error) {
}
client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL)
storeClient := llm.NewStoreClient(options.LLMAPI.APIURL, options.LLMAPI.APIKey)
c := context.Background()
if options.context != nil {
@@ -56,6 +58,7 @@ func New(opts ...Option) (*Agent, error) {
client: client,
Character: options.character,
currentState: &action.StateResult{},
storeClient: storeClient,
context: action.NewContext(ctx, cancel),
}
@@ -168,8 +171,9 @@ func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (
}
func (a *Agent) consumeJob(job *Job, role string) {
// We are self-evaluating when we consume the job as the system role
selfEvaluation := role == SystemRole
// Consume the job and generate a response
a.Lock()
// Set the action context
ctx, cancel := context.WithCancel(context.Background())
@@ -199,6 +203,42 @@ func (a *Agent) consumeJob(job *Job, role string) {
})
}
// RAG
if a.options.enableKB {
// Walk the conversation from bottom to top and take the most recent user message
// to use as the query against the KB
var userMessage string
for i := len(a.currentConversation) - 1; i >= 0; i-- {
if a.currentConversation[i].Role == "user" {
userMessage = a.currentConversation[i].Content
break
}
}
if userMessage != "" {
results, err := llm.FindSimilarStrings(a.storeClient, a.client, userMessage, a.options.kbResults)
if err != nil {
job.Result.Finish(fmt.Errorf("error finding similar strings inside KB: %w", err))
return
}
formatResults := ""
for _, r := range results {
formatResults += fmt.Sprintf("- %s \n", r)
}
if a.options.debugMode {
fmt.Println("Found similar strings in KB:")
fmt.Println(formatResults)
}
a.currentConversation = append(a.currentConversation,
openai.ChatCompletionMessage{
Role: "system",
Content: fmt.Sprintf("Given the user input you have the following in memory:\n%s", formatResults),
},
)
}
}
var pickTemplate string
var reEvaluationTemplate string
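
For readability, the retrieval step added above can be read as one standalone helper. The sketch below mirrors the committed logic; the helper name and its placement are illustrative only, and it assumes the repo's llm package plus github.com/sashabaranov/go-openai (imported as openai) and fmt.

// injectKBContext is a hypothetical helper (not part of this commit) that
// mirrors the RAG block in consumeJob: take the most recent user message,
// query the KB for similar entries, and append them as a system message.
func injectKBContext(store *llm.StoreClient, oai *openai.Client,
	conversation []openai.ChatCompletionMessage, kbResults int) ([]openai.ChatCompletionMessage, error) {
	var query string
	for i := len(conversation) - 1; i >= 0; i-- {
		if conversation[i].Role == "user" {
			query = conversation[i].Content
			break
		}
	}
	if query == "" {
		// No user message yet: nothing to look up.
		return conversation, nil
	}
	results, err := llm.FindSimilarStrings(store, oai, query, kbResults)
	if err != nil {
		return nil, fmt.Errorf("error finding similar strings inside KB: %w", err)
	}
	formatted := ""
	for _, r := range results {
		formatted += fmt.Sprintf("- %s\n", r)
	}
	return append(conversation, openai.ChatCompletionMessage{
		Role:    "system",
		Content: fmt.Sprintf("Given the user input you have the following in memory:\n%s", formatted),
	}), nil
}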

View File

@@ -19,7 +19,7 @@ type options struct {
randomIdentityGuidance string
randomIdentity bool
userActions Actions
- enableHUD, standaloneJob, showCharacter bool
+ enableHUD, standaloneJob, showCharacter, enableKB bool
debugMode bool
initiateConversations bool
characterfile string
@@ -27,6 +27,7 @@ type options struct {
context context.Context
permanentGoal string
periodicRuns time.Duration
kbResults int
// callbacks
reasoningCallback func(ActionCurrentState) bool
@@ -65,6 +66,20 @@ var EnableHUD = func(o *options) error {
return nil
}
var EnableKnowledgeBase = func(o *options) error {
o.enableKB = true
o.kbResults = 5
return nil
}
func EnableKnowledgeBaseWithResults(results int) Option {
return func(o *options) error {
o.enableKB = true
o.kbResults = results
return nil
}
}
var EnableInitiateConversations = func(o *options) error {
o.initiateConversations = true
return nil
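
A usage sketch for the two new knowledge-base options above, assuming this package is imported as agent and that the remaining constructor options (LLM API key/URL, character, and so on) are configured as elsewhere in the project; it is illustrative, not part of this commit.

// Keep the default of injecting up to 5 KB results per lookup.
withDefaults, err := agent.New(
	agent.EnableKnowledgeBase,
	// ... LLM API options, character, etc. as configured elsewhere ...
)
if err != nil {
	panic(err)
}

// Or choose how many KB entries are injected per user message.
withThree, err := agent.New(agent.EnableKnowledgeBaseWithResults(3))
if err != nil {
	panic(err)
}
_, _ = withDefaults, withThree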

View File

@@ -7,10 +7,7 @@ import (
"github.com/sashabaranov/go-openai"
)
- func StoreStringEmbeddingInVectorDB(apiHost string, openaiClient *openai.Client, s string) error {
- // Example usage
- client := NewStoreClient(apiHost)
+ func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Client, s string) error {
resp, err := openaiClient.CreateEmbeddings(context.TODO(),
openai.EmbeddingRequestStrings{
Input: []string{s},
@@ -39,8 +36,7 @@ func StoreStringEmbeddingInVectorDB(apiHost string, openaiClient *openai.Client,
return nil
}
- func FindSimilarStrings(apiHost string, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) {
- client := NewStoreClient(apiHost)
+ func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) {
resp, err := openaiClient.CreateEmbeddings(context.TODO(),
openai.EmbeddingRequestStrings{
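
After this refactor the StoreClient is constructed once by the caller (agent.New does this above) and reused by both helpers, instead of being rebuilt from apiHost on every call. A minimal sketch, assuming the package is imported as llm and that llm.NewClient returns the *openai.Client used elsewhere; the URL, key, and stored string are placeholders.

store := llm.NewStoreClient("http://localhost:8080", "api-key") // placeholder URL and key
oai := llm.NewClient("api-key", "http://localhost:8080")

// Embed and persist a string once...
if err := llm.StoreStringEmbeddingInVectorDB(store, oai, "The production database runs PostgreSQL 16"); err != nil {
	panic(err)
}

// ...then reuse the same client for similarity lookups, as consumeJob now does.
similar, err := llm.FindSimilarStrings(store, oai, "which PostgreSQL version do we run?", 5)
if err != nil {
	panic(err)
}
fmt.Println(similar)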

View File

@@ -11,6 +11,7 @@ import (
// Define a struct to hold your store API client
type StoreClient struct {
BaseURL string
APIToken string
Client *http.Client
}
@@ -45,9 +46,10 @@ type FindResponse struct {
}
// Constructor for StoreClient
- func NewStoreClient(baseUrl string) *StoreClient {
+ func NewStoreClient(baseUrl, apiToken string) *StoreClient {
return &StoreClient{
BaseURL: baseUrl,
APIToken: apiToken,
Client: &http.Client{},
}
}
@@ -105,6 +107,10 @@ func (c *StoreClient) doRequest(path string, data interface{}) error {
if err != nil {
return err
}
// Set Bearer token
if c.APIToken != "" {
req.Header.Set("Authorization", "Bearer "+c.APIToken)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.Client.Do(req)
@@ -132,7 +138,10 @@ func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]by
return nil, err
}
req.Header.Set("Content-Type", "application/json")
// Set Bearer token
if c.APIToken != "" {
req.Header.Set("Authorization", "Bearer "+c.APIToken)
}
resp, err := c.Client.Do(req)
if err != nil {
return nil, err
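
Both request paths now attach the token, so a client built once with NewStoreClient is authenticated for every store call. A short sketch (base URL and token are placeholders; package assumed to be imported as llm):

// An empty token preserves the previous unauthenticated behaviour.
store := llm.NewStoreClient("http://localhost:8080", "my-api-token")

// Every request issued through doRequest or doRequestWithResponse now sends
// "Authorization: Bearer my-api-token" alongside Content-Type: application/json.
_ = store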