rag: add KB to conversation
This commit is contained in:
@@ -23,6 +23,7 @@ type Agent struct {
|
|||||||
options *options
|
options *options
|
||||||
Character Character
|
Character Character
|
||||||
client *openai.Client
|
client *openai.Client
|
||||||
|
storeClient *llm.StoreClient
|
||||||
jobQueue chan *Job
|
jobQueue chan *Job
|
||||||
actionContext *action.ActionContext
|
actionContext *action.ActionContext
|
||||||
context *action.ActionContext
|
context *action.ActionContext
|
||||||
@@ -43,6 +44,7 @@ func New(opts ...Option) (*Agent, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL)
|
client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL)
|
||||||
|
storeClient := llm.NewStoreClient(options.LLMAPI.APIURL, options.LLMAPI.APIKey)
|
||||||
|
|
||||||
c := context.Background()
|
c := context.Background()
|
||||||
if options.context != nil {
|
if options.context != nil {
|
||||||
@@ -56,6 +58,7 @@ func New(opts ...Option) (*Agent, error) {
|
|||||||
client: client,
|
client: client,
|
||||||
Character: options.character,
|
Character: options.character,
|
||||||
currentState: &action.StateResult{},
|
currentState: &action.StateResult{},
|
||||||
|
storeClient: storeClient,
|
||||||
context: action.NewContext(ctx, cancel),
|
context: action.NewContext(ctx, cancel),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,8 +171,9 @@ func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) consumeJob(job *Job, role string) {
|
func (a *Agent) consumeJob(job *Job, role string) {
|
||||||
|
// We are self evaluating if we consume the job as a system role
|
||||||
selfEvaluation := role == SystemRole
|
selfEvaluation := role == SystemRole
|
||||||
// Consume the job and generate a response
|
|
||||||
a.Lock()
|
a.Lock()
|
||||||
// Set the action context
|
// Set the action context
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
@@ -199,6 +203,42 @@ func (a *Agent) consumeJob(job *Job, role string) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RAG
|
||||||
|
if a.options.enableKB {
|
||||||
|
// Walk conversation from bottom to top, and find the first message of the user
|
||||||
|
// to use it as a query to the KB
|
||||||
|
var userMessage string
|
||||||
|
for i := len(a.currentConversation) - 1; i >= 0; i-- {
|
||||||
|
if a.currentConversation[i].Role == "user" {
|
||||||
|
userMessage = a.currentConversation[i].Content
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if userMessage != "" {
|
||||||
|
results, err := llm.FindSimilarStrings(a.storeClient, a.client, userMessage, a.options.kbResults)
|
||||||
|
if err != nil {
|
||||||
|
job.Result.Finish(fmt.Errorf("error finding similar strings inside KB: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
formatResults := ""
|
||||||
|
for _, r := range results {
|
||||||
|
formatResults += fmt.Sprintf("- %s \n", r)
|
||||||
|
}
|
||||||
|
if a.options.debugMode {
|
||||||
|
fmt.Println("Found similar strings in KB:")
|
||||||
|
fmt.Println(formatResults)
|
||||||
|
}
|
||||||
|
a.currentConversation = append(a.currentConversation,
|
||||||
|
openai.ChatCompletionMessage{
|
||||||
|
Role: "system",
|
||||||
|
Content: fmt.Sprintf("Given the user input you have the following in memory:\n%s", formatResults),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var pickTemplate string
|
var pickTemplate string
|
||||||
var reEvaluationTemplate string
|
var reEvaluationTemplate string
|
||||||
|
|
||||||
@@ -567,7 +607,7 @@ func (a *Agent) Run() error {
|
|||||||
a.consumeJob(job, UserRole)
|
a.consumeJob(job, UserRole)
|
||||||
timer.Reset(a.options.periodicRuns)
|
timer.Reset(a.options.periodicRuns)
|
||||||
case <-a.context.Done():
|
case <-a.context.Done():
|
||||||
// Agent has been canceled, return error
|
// Agent has been canceled, return error
|
||||||
return ErrContextCanceled
|
return ErrContextCanceled
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
if !a.options.standaloneJob {
|
if !a.options.standaloneJob {
|
||||||
|
|||||||
@@ -14,19 +14,20 @@ type llmOptions struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type options struct {
|
type options struct {
|
||||||
LLMAPI llmOptions
|
LLMAPI llmOptions
|
||||||
character Character
|
character Character
|
||||||
randomIdentityGuidance string
|
randomIdentityGuidance string
|
||||||
randomIdentity bool
|
randomIdentity bool
|
||||||
userActions Actions
|
userActions Actions
|
||||||
enableHUD, standaloneJob, showCharacter bool
|
enableHUD, standaloneJob, showCharacter, enableKB bool
|
||||||
debugMode bool
|
debugMode bool
|
||||||
initiateConversations bool
|
initiateConversations bool
|
||||||
characterfile string
|
characterfile string
|
||||||
statefile string
|
statefile string
|
||||||
context context.Context
|
context context.Context
|
||||||
permanentGoal string
|
permanentGoal string
|
||||||
periodicRuns time.Duration
|
periodicRuns time.Duration
|
||||||
|
kbResults int
|
||||||
|
|
||||||
// callbacks
|
// callbacks
|
||||||
reasoningCallback func(ActionCurrentState) bool
|
reasoningCallback func(ActionCurrentState) bool
|
||||||
@@ -65,6 +66,20 @@ var EnableHUD = func(o *options) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var EnableKnowledgeBase = func(o *options) error {
|
||||||
|
o.enableKB = true
|
||||||
|
o.kbResults = 5
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func EnableKnowledgeBaseWithResults(results int) Option {
|
||||||
|
return func(o *options) error {
|
||||||
|
o.enableKB = true
|
||||||
|
o.kbResults = results
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var EnableInitiateConversations = func(o *options) error {
|
var EnableInitiateConversations = func(o *options) error {
|
||||||
o.initiateConversations = true
|
o.initiateConversations = true
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -7,10 +7,7 @@ import (
|
|||||||
"github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func StoreStringEmbeddingInVectorDB(apiHost string, openaiClient *openai.Client, s string) error {
|
func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Client, s string) error {
|
||||||
// Example usage
|
|
||||||
client := NewStoreClient(apiHost)
|
|
||||||
|
|
||||||
resp, err := openaiClient.CreateEmbeddings(context.TODO(),
|
resp, err := openaiClient.CreateEmbeddings(context.TODO(),
|
||||||
openai.EmbeddingRequestStrings{
|
openai.EmbeddingRequestStrings{
|
||||||
Input: []string{s},
|
Input: []string{s},
|
||||||
@@ -39,8 +36,7 @@ func StoreStringEmbeddingInVectorDB(apiHost string, openaiClient *openai.Client,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func FindSimilarStrings(apiHost string, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) {
|
func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) {
|
||||||
client := NewStoreClient(apiHost)
|
|
||||||
|
|
||||||
resp, err := openaiClient.CreateEmbeddings(context.TODO(),
|
resp, err := openaiClient.CreateEmbeddings(context.TODO(),
|
||||||
openai.EmbeddingRequestStrings{
|
openai.EmbeddingRequestStrings{
|
||||||
|
|||||||
21
llm/store.go
21
llm/store.go
@@ -10,8 +10,9 @@ import (
|
|||||||
|
|
||||||
// Define a struct to hold your store API client
|
// Define a struct to hold your store API client
|
||||||
type StoreClient struct {
|
type StoreClient struct {
|
||||||
BaseURL string
|
BaseURL string
|
||||||
Client *http.Client
|
APIToken string
|
||||||
|
Client *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define request and response struct formats based on the API documentation
|
// Define request and response struct formats based on the API documentation
|
||||||
@@ -45,10 +46,11 @@ type FindResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Constructor for StoreClient
|
// Constructor for StoreClient
|
||||||
func NewStoreClient(baseUrl string) *StoreClient {
|
func NewStoreClient(baseUrl, apiToken string) *StoreClient {
|
||||||
return &StoreClient{
|
return &StoreClient{
|
||||||
BaseURL: baseUrl,
|
BaseURL: baseUrl,
|
||||||
Client: &http.Client{},
|
APIToken: apiToken,
|
||||||
|
Client: &http.Client{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,6 +107,10 @@ func (c *StoreClient) doRequest(path string, data interface{}) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
// Set Bearer token
|
||||||
|
if c.APIToken != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.APIToken)
|
||||||
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
resp, err := c.Client.Do(req)
|
resp, err := c.Client.Do(req)
|
||||||
@@ -132,7 +138,10 @@ func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]by
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
// Set Bearer token
|
||||||
|
if c.APIToken != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.APIToken)
|
||||||
|
}
|
||||||
resp, err := c.Client.Do(req)
|
resp, err := c.Client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
Reference in New Issue
Block a user