feat: add retries to pickAction

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-03-28 22:04:04 +01:00
parent 05cb8ba2eb
commit 906b4ebd76
2 changed files with 52 additions and 36 deletions

View File

@@ -24,39 +24,53 @@ type decisionResult struct {
func (a *Agent) decision(
ctx context.Context,
conversation []openai.ChatCompletionMessage,
tools []openai.Tool, toolchoice any) (*decisionResult, error) {
tools []openai.Tool, toolchoice any, maxRetries int) (*decisionResult, error) {
decision := openai.ChatCompletionRequest{
Model: a.options.LLMAPI.Model,
Messages: conversation,
Tools: tools,
ToolChoice: toolchoice,
var lastErr error
for attempts := 0; attempts < maxRetries; attempts++ {
decision := openai.ChatCompletionRequest{
Model: a.options.LLMAPI.Model,
Messages: conversation,
Tools: tools,
ToolChoice: toolchoice,
}
resp, err := a.client.CreateChatCompletion(ctx, decision)
if err != nil {
lastErr = err
xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", err)
continue
}
if len(resp.Choices) != 1 {
lastErr = fmt.Errorf("no choices: %d", len(resp.Choices))
xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", lastErr)
continue
}
msg := resp.Choices[0].Message
if len(msg.ToolCalls) != 1 {
if err := a.saveConversation(append(conversation, msg), "decision"); err != nil {
xlog.Error("Error saving conversation", "error", err)
}
return &decisionResult{message: msg.Content}, nil
}
params := types.ActionParams{}
if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil {
lastErr = err
xlog.Warn("Attempt to parse action parameters failed", "attempt", attempts+1, "error", err)
continue
}
if err := a.saveConversation(append(conversation, msg), "decision"); err != nil {
xlog.Error("Error saving conversation", "error", err)
}
return &decisionResult{actionParams: params, actioName: msg.ToolCalls[0].Function.Name, message: msg.Content}, nil
}
resp, err := a.client.CreateChatCompletion(ctx, decision)
if err != nil {
return nil, err
}
if len(resp.Choices) != 1 {
return nil, fmt.Errorf("no choices: %d", len(resp.Choices))
}
msg := resp.Choices[0].Message
if len(msg.ToolCalls) != 1 {
return &decisionResult{message: msg.Content}, nil
}
params := types.ActionParams{}
if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil {
return nil, err
}
if err := a.saveConversation(append(conversation, msg), "decision"); err != nil {
xlog.Error("Error saving conversation", "error", err)
}
return &decisionResult{actionParams: params, actioName: msg.ToolCalls[0].Function.Name, message: msg.Content}, nil
return nil, fmt.Errorf("failed to make a decision after %d attempts: %w", maxRetries, lastErr)
}
type Messages []openai.ChatCompletionMessage
@@ -170,6 +184,7 @@ func (a *Agent) generateParameters(ctx context.Context, pickTemplate string, act
Type: openai.ToolTypeFunction,
Function: openai.ToolFunction{Name: act.Definition().Name.String()},
},
maxAttempts,
)
if attemptErr == nil && result.actionParams != nil {
return result, nil
@@ -340,7 +355,7 @@ func (a *Agent) prepareHUD() (promptHUD *PromptHUD) {
}
// pickAction picks an action based on the conversation
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (types.Action, types.ActionParams, string, error) {
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage, maxRetries int) (types.Action, types.ActionParams, string, error) {
c := messages
if !a.options.forceReasoning {
@@ -349,7 +364,8 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
thought, err := a.decision(ctx,
messages,
a.availableActions().ToTools(),
nil)
nil,
maxRetries)
if err != nil {
return nil, nil, "", err
}
@@ -390,7 +406,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
thought, err := a.decision(ctx,
c,
types.Actions{action.NewReasoning()}.ToTools(),
action.NewReasoning().Definition().Name)
action.NewReasoning().Definition().Name, maxRetries)
if err != nil {
return nil, nil, "", err
}
@@ -421,7 +437,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
Content: "Given the assistant thought, pick the relevant action: " + reason,
}),
types.Actions{intentionsTools}.ToTools(),
intentionsTools.Definition().Name)
intentionsTools.Definition().Name, maxRetries)
if err != nil {
return nil, nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err)
}

View File

@@ -481,7 +481,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
job.ResetNextAction()
} else {
var err error
chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv)
chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv, maxRetries)
if err != nil {
xlog.Error("Error picking action", "error", err)
job.Result.Finish(err)
@@ -634,7 +634,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
// given the result, we can now ask OpenAI to complete the conversation or
// to continue using another tool given the result
followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv)
followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv, maxRetries)
if err != nil {
job.Result.Conversation = conv
job.Result.Finish(fmt.Errorf("error picking action: %w", err))