From 906b4ebd76de450dcbe4ee3ae071c70f6f1b2f2e Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 28 Mar 2025 22:04:04 +0100 Subject: [PATCH] feat: add retries to pickAction Signed-off-by: Ettore Di Giacinto --- core/agent/actions.go | 84 +++++++++++++++++++++++++------------------ core/agent/agent.go | 4 +-- 2 files changed, 52 insertions(+), 36 deletions(-) diff --git a/core/agent/actions.go b/core/agent/actions.go index 033b812..4a954be 100644 --- a/core/agent/actions.go +++ b/core/agent/actions.go @@ -24,39 +24,53 @@ type decisionResult struct { func (a *Agent) decision( ctx context.Context, conversation []openai.ChatCompletionMessage, - tools []openai.Tool, toolchoice any) (*decisionResult, error) { + tools []openai.Tool, toolchoice any, maxRetries int) (*decisionResult, error) { - decision := openai.ChatCompletionRequest{ - Model: a.options.LLMAPI.Model, - Messages: conversation, - Tools: tools, - ToolChoice: toolchoice, + var lastErr error + for attempts := 0; attempts < maxRetries; attempts++ { + decision := openai.ChatCompletionRequest{ + Model: a.options.LLMAPI.Model, + Messages: conversation, + Tools: tools, + ToolChoice: toolchoice, + } + + resp, err := a.client.CreateChatCompletion(ctx, decision) + if err != nil { + lastErr = err + xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", err) + continue + } + + if len(resp.Choices) != 1 { + lastErr = fmt.Errorf("no choices: %d", len(resp.Choices)) + xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", lastErr) + continue + } + + msg := resp.Choices[0].Message + if len(msg.ToolCalls) != 1 { + if err := a.saveConversation(append(conversation, msg), "decision"); err != nil { + xlog.Error("Error saving conversation", "error", err) + } + return &decisionResult{message: msg.Content}, nil + } + + params := types.ActionParams{} + if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil { + lastErr = err + xlog.Warn("Attempt to parse action parameters failed", "attempt", attempts+1, "error", err) + continue + } + + if err := a.saveConversation(append(conversation, msg), "decision"); err != nil { + xlog.Error("Error saving conversation", "error", err) + } + + return &decisionResult{actionParams: params, actioName: msg.ToolCalls[0].Function.Name, message: msg.Content}, nil } - resp, err := a.client.CreateChatCompletion(ctx, decision) - if err != nil { - return nil, err - } - - if len(resp.Choices) != 1 { - return nil, fmt.Errorf("no choices: %d", len(resp.Choices)) - } - - msg := resp.Choices[0].Message - if len(msg.ToolCalls) != 1 { - return &decisionResult{message: msg.Content}, nil - } - - params := types.ActionParams{} - if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil { - return nil, err - } - - if err := a.saveConversation(append(conversation, msg), "decision"); err != nil { - xlog.Error("Error saving conversation", "error", err) - } - - return &decisionResult{actionParams: params, actioName: msg.ToolCalls[0].Function.Name, message: msg.Content}, nil + return nil, fmt.Errorf("failed to make a decision after %d attempts: %w", maxRetries, lastErr) } type Messages []openai.ChatCompletionMessage @@ -170,6 +184,7 @@ func (a *Agent) generateParameters(ctx context.Context, pickTemplate string, act Type: openai.ToolTypeFunction, Function: openai.ToolFunction{Name: act.Definition().Name.String()}, }, + maxAttempts, ) if attemptErr == nil && result.actionParams != nil { return result, nil @@ -340,7 +355,7 @@ func (a *Agent) prepareHUD() (promptHUD *PromptHUD) { } // pickAction picks an action based on the conversation -func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (types.Action, types.ActionParams, string, error) { +func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage, maxRetries int) (types.Action, types.ActionParams, string, error) { c := messages if !a.options.forceReasoning { @@ -349,7 +364,8 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai. thought, err := a.decision(ctx, messages, a.availableActions().ToTools(), - nil) + nil, + maxRetries) if err != nil { return nil, nil, "", err } @@ -390,7 +406,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai. thought, err := a.decision(ctx, c, types.Actions{action.NewReasoning()}.ToTools(), - action.NewReasoning().Definition().Name) + action.NewReasoning().Definition().Name, maxRetries) if err != nil { return nil, nil, "", err } @@ -421,7 +437,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai. Content: "Given the assistant thought, pick the relevant action: " + reason, }), types.Actions{intentionsTools}.ToTools(), - intentionsTools.Definition().Name) + intentionsTools.Definition().Name, maxRetries) if err != nil { return nil, nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err) } diff --git a/core/agent/agent.go b/core/agent/agent.go index 4c8f5ca..7712ea7 100644 --- a/core/agent/agent.go +++ b/core/agent/agent.go @@ -481,7 +481,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) { job.ResetNextAction() } else { var err error - chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv) + chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv, maxRetries) if err != nil { xlog.Error("Error picking action", "error", err) job.Result.Finish(err) @@ -634,7 +634,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) { // given the result, we can now ask OpenAI to complete the conversation or // to continue using another tool given the result - followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv) + followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv, maxRetries) if err != nil { job.Result.Conversation = conv job.Result.Finish(fmt.Errorf("error picking action: %w", err))