diff --git a/agent/agent.go b/agent/agent.go
index df2fd10..b1d8097 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -72,6 +72,18 @@ func (a *Agent) Ask(opts ...JobOption) []ActionState {
 	return j.Result.WaitResult()
 }
 
+func (a *Agent) CurrentConversation() []openai.ChatCompletionMessage {
+	a.Lock()
+	defer a.Unlock()
+	return a.currentConversation
+}
+
+func (a *Agent) ResetConversation() {
+	a.Lock()
+	defer a.Unlock()
+	a.currentConversation = []openai.ChatCompletionMessage{}
+}
+
 var ErrContextCanceled = fmt.Errorf("context canceled")
 
 func (a *Agent) Stop() {
@@ -109,7 +121,7 @@ func (a *Agent) Run() error {
 			// before clearing it out
 
 			// Clear the conversation
-			a.currentConversation = []openai.ChatCompletionMessage{}
+			// a.currentConversation = []openai.ChatCompletionMessage{}
 		}
 	}
 }
diff --git a/agent/agent_test.go b/agent/agent_test.go
index 312f305..1f19c18 100644
--- a/agent/agent_test.go
+++ b/agent/agent_test.go
@@ -12,6 +12,7 @@ import (
 
 const testActionResult = "In Boston it's 30C today, it's sunny, and humidity is at 98%"
 const testActionResult2 = "In milan it's very hot today, it is 45C and the humidity is at 200%"
+const testActionResult3 = "In paris it's very cold today, it is 2C and the humidity is at 10%"
 
 var _ Action = &TestAction{}
 
@@ -22,11 +23,12 @@ type TestAction struct {
 
 func (a *TestAction) Run(action.ActionParams) (string, error) {
 	res := a.response[a.responseN]
+	a.responseN++
+
 	if len(a.response) == a.responseN {
 		a.responseN = 0
-	} else {
-		a.responseN++
 	}
+
 	return res, nil
 }
 
@@ -56,7 +58,7 @@ var _ = Describe("Agent test", func() {
 			WithLLMAPIURL(apiModel),
 			WithModel(testModel),
 			// WithRandomIdentity(),
-			WithActions(&TestAction{response: []string{testActionResult, testActionResult2}}),
+			WithActions(&TestAction{response: []string{testActionResult, testActionResult2, testActionResult3}}),
 		)
 		Expect(err).ToNot(HaveOccurred())
 		go agent.Run()
diff --git a/agent/jobs.go b/agent/jobs.go
index 53b1a96..1d1640a 100644
--- a/agent/jobs.go
+++ b/agent/jobs.go
@@ -156,9 +156,8 @@ func (a *Agent) consumeJob(job *Job) {
 		// TODO: Use llava to explain the image content
 	}
 
-	messages := a.currentConversation
 	if job.Text != "" {
-		messages = append(messages, openai.ChatCompletionMessage{
+		a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{
 			Role:    "user",
 			Content: job.Text,
 		})
@@ -177,7 +176,7 @@
 		a.nextAction = nil
 	} else {
 		var err error
-		chosenAction, reasoning, err = a.pickAction(ctx, pickActionTemplate, messages)
+		chosenAction, reasoning, err = a.pickAction(ctx, pickActionTemplate, a.currentConversation)
 		if err != nil {
 			fmt.Printf("error picking action: %v\n", err)
 			return
@@ -189,7 +188,7 @@
 		return
 	}
 
-	params, err := a.generateParameters(ctx, chosenAction, messages)
+	params, err := a.generateParameters(ctx, chosenAction, a.currentConversation)
 	if err != nil {
 		fmt.Printf("error generating parameters: %v\n", err)
 		return
@@ -224,7 +223,7 @@
 	job.CallbackWithResult(stateResult)
 
 	// calling the function
-	messages = append(messages, openai.ChatCompletionMessage{
+	a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{
 		Role: "assistant",
 		FunctionCall: &openai.FunctionCall{
 			Name:      chosenAction.Definition().Name.String(),
@@ -233,18 +232,19 @@
 	})
 
 	// result of calling the function
-	messages = append(messages, openai.ChatCompletionMessage{
+	a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{
 		Role:       openai.ChatMessageRoleTool,
 		Content:    result,
 		Name:       chosenAction.Definition().Name.String(),
 		ToolCallID: chosenAction.Definition().Name.String(),
 	})
 
-	a.currentConversation = append(a.currentConversation, messages...)
+	//a.currentConversation = append(a.currentConversation, messages...)
+	//a.currentConversation = messages
 
 	// given the result, we can now ask OpenAI to complete the conversation or
 	// to continue using another tool given the result
-	followingAction, reasoning, err := a.pickAction(ctx, reEvalTemplate, messages)
+	followingAction, reasoning, err := a.pickAction(ctx, reEvalTemplate, a.currentConversation)
 	if err != nil {
 		fmt.Printf("error picking action: %v\n", err)
 		return
@@ -267,7 +267,7 @@
 
 	resp, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
 		Model:    a.options.LLMAPI.Model,
-		Messages: messages,
+		Messages: a.currentConversation,
 	},
 	)
 	if err != nil || len(resp.Choices) != 1 {
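
Not part of the patch: a minimal sketch of how a caller might use the two accessors added in agent.go. Only CurrentConversation, ResetConversation, and the go-openai message fields come from the diff above; the package name, import path, and helper function are assumed for illustration. Since consumeJob now appends directly to currentConversation and Run() no longer clears it, resetting becomes the caller's responsibility.

package agentdemo // hypothetical package, not part of the repository

import (
	"fmt"

	"example.com/yourmodule/agent" // assumed import path; replace with the real module path
)

// dumpAndReset prints the conversation accumulated so far and then clears it.
func dumpAndReset(a *agent.Agent) {
	// CurrentConversation copies out the messages under the agent's lock.
	for _, msg := range a.CurrentConversation() {
		fmt.Printf("%s: %s\n", msg.Role, msg.Content)
	}
	// With the automatic clearing in Run() commented out, the caller decides
	// when to start a fresh context.
	a.ResetConversation()
}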