rag: add KB to conversation

This commit is contained in:
Ettore Di Giacinto
2024-04-09 22:34:22 +02:00
parent 36abf837a9
commit 78ba7871e9
4 changed files with 87 additions and 27 deletions

View File

@@ -23,6 +23,7 @@ type Agent struct {
options *options options *options
Character Character Character Character
client *openai.Client client *openai.Client
storeClient *llm.StoreClient
jobQueue chan *Job jobQueue chan *Job
actionContext *action.ActionContext actionContext *action.ActionContext
context *action.ActionContext context *action.ActionContext
@@ -43,6 +44,7 @@ func New(opts ...Option) (*Agent, error) {
} }
client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL) client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL)
storeClient := llm.NewStoreClient(options.LLMAPI.APIURL, options.LLMAPI.APIKey)
c := context.Background() c := context.Background()
if options.context != nil { if options.context != nil {
@@ -56,6 +58,7 @@ func New(opts ...Option) (*Agent, error) {
client: client, client: client,
Character: options.character, Character: options.character,
currentState: &action.StateResult{}, currentState: &action.StateResult{},
storeClient: storeClient,
context: action.NewContext(ctx, cancel), context: action.NewContext(ctx, cancel),
} }
@@ -168,8 +171,9 @@ func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (
} }
func (a *Agent) consumeJob(job *Job, role string) { func (a *Agent) consumeJob(job *Job, role string) {
// We are self evaluating if we consume the job as a system role
selfEvaluation := role == SystemRole selfEvaluation := role == SystemRole
// Consume the job and generate a response
a.Lock() a.Lock()
// Set the action context // Set the action context
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@@ -199,6 +203,42 @@ func (a *Agent) consumeJob(job *Job, role string) {
}) })
} }
// RAG
if a.options.enableKB {
// Walk conversation from bottom to top, and find the first message of the user
// to use it as a query to the KB
var userMessage string
for i := len(a.currentConversation) - 1; i >= 0; i-- {
if a.currentConversation[i].Role == "user" {
userMessage = a.currentConversation[i].Content
break
}
}
if userMessage != "" {
results, err := llm.FindSimilarStrings(a.storeClient, a.client, userMessage, a.options.kbResults)
if err != nil {
job.Result.Finish(fmt.Errorf("error finding similar strings inside KB: %w", err))
return
}
formatResults := ""
for _, r := range results {
formatResults += fmt.Sprintf("- %s \n", r)
}
if a.options.debugMode {
fmt.Println("Found similar strings in KB:")
fmt.Println(formatResults)
}
a.currentConversation = append(a.currentConversation,
openai.ChatCompletionMessage{
Role: "system",
Content: fmt.Sprintf("Given the user input you have the following in memory:\n%s", formatResults),
},
)
}
}
var pickTemplate string var pickTemplate string
var reEvaluationTemplate string var reEvaluationTemplate string
@@ -567,7 +607,7 @@ func (a *Agent) Run() error {
a.consumeJob(job, UserRole) a.consumeJob(job, UserRole)
timer.Reset(a.options.periodicRuns) timer.Reset(a.options.periodicRuns)
case <-a.context.Done(): case <-a.context.Done():
// Agent has been canceled, return error // Agent has been canceled, return error
return ErrContextCanceled return ErrContextCanceled
case <-timer.C: case <-timer.C:
if !a.options.standaloneJob { if !a.options.standaloneJob {

View File

@@ -14,19 +14,20 @@ type llmOptions struct {
} }
type options struct { type options struct {
LLMAPI llmOptions LLMAPI llmOptions
character Character character Character
randomIdentityGuidance string randomIdentityGuidance string
randomIdentity bool randomIdentity bool
userActions Actions userActions Actions
enableHUD, standaloneJob, showCharacter bool enableHUD, standaloneJob, showCharacter, enableKB bool
debugMode bool debugMode bool
initiateConversations bool initiateConversations bool
characterfile string characterfile string
statefile string statefile string
context context.Context context context.Context
permanentGoal string permanentGoal string
periodicRuns time.Duration periodicRuns time.Duration
kbResults int
// callbacks // callbacks
reasoningCallback func(ActionCurrentState) bool reasoningCallback func(ActionCurrentState) bool
@@ -65,6 +66,20 @@ var EnableHUD = func(o *options) error {
return nil return nil
} }
var EnableKnowledgeBase = func(o *options) error {
o.enableKB = true
o.kbResults = 5
return nil
}
func EnableKnowledgeBaseWithResults(results int) Option {
return func(o *options) error {
o.enableKB = true
o.kbResults = results
return nil
}
}
var EnableInitiateConversations = func(o *options) error { var EnableInitiateConversations = func(o *options) error {
o.initiateConversations = true o.initiateConversations = true
return nil return nil

View File

@@ -7,10 +7,7 @@ import (
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
) )
func StoreStringEmbeddingInVectorDB(apiHost string, openaiClient *openai.Client, s string) error { func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Client, s string) error {
// Example usage
client := NewStoreClient(apiHost)
resp, err := openaiClient.CreateEmbeddings(context.TODO(), resp, err := openaiClient.CreateEmbeddings(context.TODO(),
openai.EmbeddingRequestStrings{ openai.EmbeddingRequestStrings{
Input: []string{s}, Input: []string{s},
@@ -39,8 +36,7 @@ func StoreStringEmbeddingInVectorDB(apiHost string, openaiClient *openai.Client,
return nil return nil
} }
func FindSimilarStrings(apiHost string, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) { func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) {
client := NewStoreClient(apiHost)
resp, err := openaiClient.CreateEmbeddings(context.TODO(), resp, err := openaiClient.CreateEmbeddings(context.TODO(),
openai.EmbeddingRequestStrings{ openai.EmbeddingRequestStrings{

View File

@@ -10,8 +10,9 @@ import (
// Define a struct to hold your store API client // Define a struct to hold your store API client
type StoreClient struct { type StoreClient struct {
BaseURL string BaseURL string
Client *http.Client APIToken string
Client *http.Client
} }
// Define request and response struct formats based on the API documentation // Define request and response struct formats based on the API documentation
@@ -45,10 +46,11 @@ type FindResponse struct {
} }
// Constructor for StoreClient // Constructor for StoreClient
func NewStoreClient(baseUrl string) *StoreClient { func NewStoreClient(baseUrl, apiToken string) *StoreClient {
return &StoreClient{ return &StoreClient{
BaseURL: baseUrl, BaseURL: baseUrl,
Client: &http.Client{}, APIToken: apiToken,
Client: &http.Client{},
} }
} }
@@ -105,6 +107,10 @@ func (c *StoreClient) doRequest(path string, data interface{}) error {
if err != nil { if err != nil {
return err return err
} }
// Set Bearer token
if c.APIToken != "" {
req.Header.Set("Authorization", "Bearer "+c.APIToken)
}
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := c.Client.Do(req) resp, err := c.Client.Do(req)
@@ -132,7 +138,10 @@ func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]by
return nil, err return nil, err
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
// Set Bearer token
if c.APIToken != "" {
req.Header.Set("Authorization", "Bearer "+c.APIToken)
}
resp, err := c.Client.Do(req) resp, err := c.Client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err