correctly store conversation
This commit is contained in:
@@ -72,6 +72,18 @@ func (a *Agent) Ask(opts ...JobOption) []ActionState {
|
|||||||
return j.Result.WaitResult()
|
return j.Result.WaitResult()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Agent) CurrentConversation() []openai.ChatCompletionMessage {
|
||||||
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
return a.currentConversation
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) ResetConversation() {
|
||||||
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
a.currentConversation = []openai.ChatCompletionMessage{}
|
||||||
|
}
|
||||||
|
|
||||||
var ErrContextCanceled = fmt.Errorf("context canceled")
|
var ErrContextCanceled = fmt.Errorf("context canceled")
|
||||||
|
|
||||||
func (a *Agent) Stop() {
|
func (a *Agent) Stop() {
|
||||||
@@ -109,7 +121,7 @@ func (a *Agent) Run() error {
|
|||||||
// before clearing it out
|
// before clearing it out
|
||||||
|
|
||||||
// Clear the conversation
|
// Clear the conversation
|
||||||
a.currentConversation = []openai.ChatCompletionMessage{}
|
// a.currentConversation = []openai.ChatCompletionMessage{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
const testActionResult = "In Boston it's 30C today, it's sunny, and humidity is at 98%"
|
const testActionResult = "In Boston it's 30C today, it's sunny, and humidity is at 98%"
|
||||||
const testActionResult2 = "In milan it's very hot today, it is 45C and the humidity is at 200%"
|
const testActionResult2 = "In milan it's very hot today, it is 45C and the humidity is at 200%"
|
||||||
|
const testActionResult3 = "In paris it's very cold today, it is 2C and the humidity is at 10%"
|
||||||
|
|
||||||
var _ Action = &TestAction{}
|
var _ Action = &TestAction{}
|
||||||
|
|
||||||
@@ -22,11 +23,12 @@ type TestAction struct {
|
|||||||
|
|
||||||
func (a *TestAction) Run(action.ActionParams) (string, error) {
|
func (a *TestAction) Run(action.ActionParams) (string, error) {
|
||||||
res := a.response[a.responseN]
|
res := a.response[a.responseN]
|
||||||
|
a.responseN++
|
||||||
|
|
||||||
if len(a.response) == a.responseN {
|
if len(a.response) == a.responseN {
|
||||||
a.responseN = 0
|
a.responseN = 0
|
||||||
} else {
|
|
||||||
a.responseN++
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,7 +58,7 @@ var _ = Describe("Agent test", func() {
|
|||||||
WithLLMAPIURL(apiModel),
|
WithLLMAPIURL(apiModel),
|
||||||
WithModel(testModel),
|
WithModel(testModel),
|
||||||
// WithRandomIdentity(),
|
// WithRandomIdentity(),
|
||||||
WithActions(&TestAction{response: []string{testActionResult, testActionResult2}}),
|
WithActions(&TestAction{response: []string{testActionResult, testActionResult2, testActionResult3}}),
|
||||||
)
|
)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
go agent.Run()
|
go agent.Run()
|
||||||
|
|||||||
@@ -156,9 +156,8 @@ func (a *Agent) consumeJob(job *Job) {
|
|||||||
// TODO: Use llava to explain the image content
|
// TODO: Use llava to explain the image content
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := a.currentConversation
|
|
||||||
if job.Text != "" {
|
if job.Text != "" {
|
||||||
messages = append(messages, openai.ChatCompletionMessage{
|
a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: job.Text,
|
Content: job.Text,
|
||||||
})
|
})
|
||||||
@@ -177,7 +176,7 @@ func (a *Agent) consumeJob(job *Job) {
|
|||||||
a.nextAction = nil
|
a.nextAction = nil
|
||||||
} else {
|
} else {
|
||||||
var err error
|
var err error
|
||||||
chosenAction, reasoning, err = a.pickAction(ctx, pickActionTemplate, messages)
|
chosenAction, reasoning, err = a.pickAction(ctx, pickActionTemplate, a.currentConversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("error picking action: %v\n", err)
|
fmt.Printf("error picking action: %v\n", err)
|
||||||
return
|
return
|
||||||
@@ -189,7 +188,7 @@ func (a *Agent) consumeJob(job *Job) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
params, err := a.generateParameters(ctx, chosenAction, messages)
|
params, err := a.generateParameters(ctx, chosenAction, a.currentConversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("error generating parameters: %v\n", err)
|
fmt.Printf("error generating parameters: %v\n", err)
|
||||||
return
|
return
|
||||||
@@ -224,7 +223,7 @@ func (a *Agent) consumeJob(job *Job) {
|
|||||||
job.CallbackWithResult(stateResult)
|
job.CallbackWithResult(stateResult)
|
||||||
|
|
||||||
// calling the function
|
// calling the function
|
||||||
messages = append(messages, openai.ChatCompletionMessage{
|
a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
FunctionCall: &openai.FunctionCall{
|
FunctionCall: &openai.FunctionCall{
|
||||||
Name: chosenAction.Definition().Name.String(),
|
Name: chosenAction.Definition().Name.String(),
|
||||||
@@ -233,18 +232,19 @@ func (a *Agent) consumeJob(job *Job) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// result of calling the function
|
// result of calling the function
|
||||||
messages = append(messages, openai.ChatCompletionMessage{
|
a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{
|
||||||
Role: openai.ChatMessageRoleTool,
|
Role: openai.ChatMessageRoleTool,
|
||||||
Content: result,
|
Content: result,
|
||||||
Name: chosenAction.Definition().Name.String(),
|
Name: chosenAction.Definition().Name.String(),
|
||||||
ToolCallID: chosenAction.Definition().Name.String(),
|
ToolCallID: chosenAction.Definition().Name.String(),
|
||||||
})
|
})
|
||||||
|
|
||||||
a.currentConversation = append(a.currentConversation, messages...)
|
//a.currentConversation = append(a.currentConversation, messages...)
|
||||||
|
//a.currentConversation = messages
|
||||||
|
|
||||||
// given the result, we can now ask OpenAI to complete the conversation or
|
// given the result, we can now ask OpenAI to complete the conversation or
|
||||||
// to continue using another tool given the result
|
// to continue using another tool given the result
|
||||||
followingAction, reasoning, err := a.pickAction(ctx, reEvalTemplate, messages)
|
followingAction, reasoning, err := a.pickAction(ctx, reEvalTemplate, a.currentConversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("error picking action: %v\n", err)
|
fmt.Printf("error picking action: %v\n", err)
|
||||||
return
|
return
|
||||||
@@ -267,7 +267,7 @@ func (a *Agent) consumeJob(job *Job) {
|
|||||||
resp, err := a.client.CreateChatCompletion(ctx,
|
resp, err := a.client.CreateChatCompletion(ctx,
|
||||||
openai.ChatCompletionRequest{
|
openai.ChatCompletionRequest{
|
||||||
Model: a.options.LLMAPI.Model,
|
Model: a.options.LLMAPI.Model,
|
||||||
Messages: messages,
|
Messages: a.currentConversation,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil || len(resp.Choices) != 1 {
|
if err != nil || len(resp.Choices) != 1 {
|
||||||
|
|||||||
Reference in New Issue
Block a user