correctly store conversation

This commit is contained in:
mudler
2024-04-02 17:32:27 +02:00
parent 8e3a1fcbe5
commit 9417c5ca8f
3 changed files with 27 additions and 13 deletions

View File

@@ -72,6 +72,18 @@ func (a *Agent) Ask(opts ...JobOption) []ActionState {
return j.Result.WaitResult()
}
func (a *Agent) CurrentConversation() []openai.ChatCompletionMessage {
a.Lock()
defer a.Unlock()
return a.currentConversation
}
func (a *Agent) ResetConversation() {
a.Lock()
defer a.Unlock()
a.currentConversation = []openai.ChatCompletionMessage{}
}
var ErrContextCanceled = fmt.Errorf("context canceled")
func (a *Agent) Stop() {
@@ -109,7 +121,7 @@ func (a *Agent) Run() error {
// before clearing it out
// Clear the conversation
a.currentConversation = []openai.ChatCompletionMessage{} // a.currentConversation = []openai.ChatCompletionMessage{}
}
}
}

View File

@@ -12,6 +12,7 @@ import (
const testActionResult = "In Boston it's 30C today, it's sunny, and humidity is at 98%"
const testActionResult2 = "In milan it's very hot today, it is 45C and the humidity is at 200%"
const testActionResult3 = "In paris it's very cold today, it is 2C and the humidity is at 10%"
var _ Action = &TestAction{}
@@ -22,11 +23,12 @@ type TestAction struct {
func (a *TestAction) Run(action.ActionParams) (string, error) {
res := a.response[a.responseN]
a.responseN++
if len(a.response) == a.responseN {
a.responseN = 0
} else {
a.responseN++
}
return res, nil
}
@@ -56,7 +58,7 @@ var _ = Describe("Agent test", func() {
WithLLMAPIURL(apiModel),
WithModel(testModel),
// WithRandomIdentity(),
WithActions(&TestAction{response: []string{testActionResult, testActionResult2}}), WithActions(&TestAction{response: []string{testActionResult, testActionResult2, testActionResult3}}),
)
Expect(err).ToNot(HaveOccurred())
go agent.Run()

View File

@@ -156,9 +156,8 @@ func (a *Agent) consumeJob(job *Job) {
// TODO: Use llava to explain the image content
}
messages := a.currentConversation
if job.Text != "" {
messages = append(messages, openai.ChatCompletionMessage{ a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{
Role: "user",
Content: job.Text,
})
@@ -177,7 +176,7 @@ func (a *Agent) consumeJob(job *Job) {
a.nextAction = nil
} else {
var err error
chosenAction, reasoning, err = a.pickAction(ctx, pickActionTemplate, messages) chosenAction, reasoning, err = a.pickAction(ctx, pickActionTemplate, a.currentConversation)
if err != nil {
fmt.Printf("error picking action: %v\n", err)
return
@@ -189,7 +188,7 @@ func (a *Agent) consumeJob(job *Job) {
return
}
params, err := a.generateParameters(ctx, chosenAction, messages) params, err := a.generateParameters(ctx, chosenAction, a.currentConversation)
if err != nil {
fmt.Printf("error generating parameters: %v\n", err)
return
@@ -224,7 +223,7 @@ func (a *Agent) consumeJob(job *Job) {
job.CallbackWithResult(stateResult)
// calling the function
messages = append(messages, openai.ChatCompletionMessage{ a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{
Role: "assistant",
FunctionCall: &openai.FunctionCall{
Name: chosenAction.Definition().Name.String(),
@@ -233,18 +232,19 @@ func (a *Agent) consumeJob(job *Job) {
})
// result of calling the function
messages = append(messages, openai.ChatCompletionMessage{ a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleTool,
Content: result,
Name: chosenAction.Definition().Name.String(),
ToolCallID: chosenAction.Definition().Name.String(),
})
a.currentConversation = append(a.currentConversation, messages...) //a.currentConversation = append(a.currentConversation, messages...)
//a.currentConversation = messages
// given the result, we can now ask OpenAI to complete the conversation or
// to continue using another tool given the result
followingAction, reasoning, err := a.pickAction(ctx, reEvalTemplate, messages) followingAction, reasoning, err := a.pickAction(ctx, reEvalTemplate, a.currentConversation)
if err != nil {
fmt.Printf("error picking action: %v\n", err)
return
@@ -267,7 +267,7 @@ func (a *Agent) consumeJob(job *Job) {
resp, err := a.client.CreateChatCompletion(ctx,
openai.ChatCompletionRequest{
Model: a.options.LLMAPI.Model,
Messages: messages, Messages: a.currentConversation,
},
)
if err != nil || len(resp.Choices) != 1 {