diff --git a/agent/agent.go b/agent/agent.go index 59e996c..662391a 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -23,6 +23,7 @@ type Agent struct { options *options Character Character client *openai.Client + storeClient *llm.StoreClient jobQueue chan *Job actionContext *action.ActionContext context *action.ActionContext @@ -43,6 +44,7 @@ func New(opts ...Option) (*Agent, error) { } client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL) + storeClient := llm.NewStoreClient(options.LLMAPI.APIURL, options.LLMAPI.APIKey) c := context.Background() if options.context != nil { @@ -56,6 +58,7 @@ func New(opts ...Option) (*Agent, error) { client: client, Character: options.character, currentState: &action.StateResult{}, + storeClient: storeClient, context: action.NewContext(ctx, cancel), } @@ -168,8 +171,9 @@ func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) ( } func (a *Agent) consumeJob(job *Job, role string) { + // We are self evaluating if we consume the job as a system role selfEvaluation := role == SystemRole - // Consume the job and generate a response + a.Lock() // Set the action context ctx, cancel := context.WithCancel(context.Background()) @@ -199,6 +203,42 @@ func (a *Agent) consumeJob(job *Job, role string) { }) } + // RAG + if a.options.enableKB { + // Walk conversation from bottom to top, and find the first message of the user + // to use it as a query to the KB + var userMessage string + for i := len(a.currentConversation) - 1; i >= 0; i-- { + if a.currentConversation[i].Role == "user" { + userMessage = a.currentConversation[i].Content + break + } + } + + if userMessage != "" { + results, err := llm.FindSimilarStrings(a.storeClient, a.client, userMessage, a.options.kbResults) + if err != nil { + job.Result.Finish(fmt.Errorf("error finding similar strings inside KB: %w", err)) + return + } + + formatResults := "" + for _, r := range results { + formatResults += fmt.Sprintf("- %s \n", r) + } + if a.options.debugMode { + fmt.Println("Found similar strings in KB:") + fmt.Println(formatResults) + } + a.currentConversation = append(a.currentConversation, + openai.ChatCompletionMessage{ + Role: "system", + Content: fmt.Sprintf("Given the user input you have the following in memory:\n%s", formatResults), + }, + ) + } + } + var pickTemplate string var reEvaluationTemplate string @@ -567,7 +607,7 @@ func (a *Agent) Run() error { a.consumeJob(job, UserRole) timer.Reset(a.options.periodicRuns) case <-a.context.Done(): - // Agent has been canceled, return error + // Agent has been canceled, return error return ErrContextCanceled case <-timer.C: if !a.options.standaloneJob { diff --git a/agent/options.go b/agent/options.go index 441117a..1d7cde8 100644 --- a/agent/options.go +++ b/agent/options.go @@ -14,19 +14,20 @@ type llmOptions struct { } type options struct { - LLMAPI llmOptions - character Character - randomIdentityGuidance string - randomIdentity bool - userActions Actions - enableHUD, standaloneJob, showCharacter bool - debugMode bool - initiateConversations bool - characterfile string - statefile string - context context.Context - permanentGoal string - periodicRuns time.Duration + LLMAPI llmOptions + character Character + randomIdentityGuidance string + randomIdentity bool + userActions Actions + enableHUD, standaloneJob, showCharacter, enableKB bool + debugMode bool + initiateConversations bool + characterfile string + statefile string + context context.Context + permanentGoal string + periodicRuns time.Duration + kbResults int // callbacks reasoningCallback func(ActionCurrentState) bool @@ -65,6 +66,20 @@ var EnableHUD = func(o *options) error { return nil } +var EnableKnowledgeBase = func(o *options) error { + o.enableKB = true + o.kbResults = 5 + return nil +} + +func EnableKnowledgeBaseWithResults(results int) Option { + return func(o *options) error { + o.enableKB = true + o.kbResults = results + return nil + } +} + var EnableInitiateConversations = func(o *options) error { o.initiateConversations = true return nil diff --git a/llm/rag.go b/llm/rag.go index 10ca98e..af0371b 100644 --- a/llm/rag.go +++ b/llm/rag.go @@ -7,10 +7,7 @@ import ( "github.com/sashabaranov/go-openai" ) -func StoreStringEmbeddingInVectorDB(apiHost string, openaiClient *openai.Client, s string) error { - // Example usage - client := NewStoreClient(apiHost) - +func StoreStringEmbeddingInVectorDB(client *StoreClient, openaiClient *openai.Client, s string) error { resp, err := openaiClient.CreateEmbeddings(context.TODO(), openai.EmbeddingRequestStrings{ Input: []string{s}, @@ -39,8 +36,7 @@ func StoreStringEmbeddingInVectorDB(apiHost string, openaiClient *openai.Client, return nil } -func FindSimilarStrings(apiHost string, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) { - client := NewStoreClient(apiHost) +func FindSimilarStrings(client *StoreClient, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) { resp, err := openaiClient.CreateEmbeddings(context.TODO(), openai.EmbeddingRequestStrings{ diff --git a/llm/store.go b/llm/store.go index ba3bebe..3b85568 100644 --- a/llm/store.go +++ b/llm/store.go @@ -10,8 +10,9 @@ import ( // Define a struct to hold your store API client type StoreClient struct { - BaseURL string - Client *http.Client + BaseURL string + APIToken string + Client *http.Client } // Define request and response struct formats based on the API documentation @@ -45,10 +46,11 @@ type FindResponse struct { } // Constructor for StoreClient -func NewStoreClient(baseUrl string) *StoreClient { +func NewStoreClient(baseUrl, apiToken string) *StoreClient { return &StoreClient{ - BaseURL: baseUrl, - Client: &http.Client{}, + BaseURL: baseUrl, + APIToken: apiToken, + Client: &http.Client{}, } } @@ -105,6 +107,10 @@ func (c *StoreClient) doRequest(path string, data interface{}) error { if err != nil { return err } + // Set Bearer token + if c.APIToken != "" { + req.Header.Set("Authorization", "Bearer "+c.APIToken) + } req.Header.Set("Content-Type", "application/json") resp, err := c.Client.Do(req) @@ -132,7 +138,10 @@ func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]by return nil, err } req.Header.Set("Content-Type", "application/json") - + // Set Bearer token + if c.APIToken != "" { + req.Header.Set("Authorization", "Bearer "+c.APIToken) + } resp, err := c.Client.Do(req) if err != nil { return nil, err