feat: add retries to pickAction

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-03-28 22:04:04 +01:00
parent 05cb8ba2eb
commit 906b4ebd76
2 changed files with 52 additions and 36 deletions

View File

@@ -24,8 +24,10 @@ type decisionResult struct {
func (a *Agent) decision( func (a *Agent) decision(
ctx context.Context, ctx context.Context,
conversation []openai.ChatCompletionMessage, conversation []openai.ChatCompletionMessage,
tools []openai.Tool, toolchoice any) (*decisionResult, error) { tools []openai.Tool, toolchoice any, maxRetries int) (*decisionResult, error) {
var lastErr error
for attempts := 0; attempts < maxRetries; attempts++ {
decision := openai.ChatCompletionRequest{ decision := openai.ChatCompletionRequest{
Model: a.options.LLMAPI.Model, Model: a.options.LLMAPI.Model,
Messages: conversation, Messages: conversation,
@@ -35,21 +37,30 @@ func (a *Agent) decision(
resp, err := a.client.CreateChatCompletion(ctx, decision) resp, err := a.client.CreateChatCompletion(ctx, decision)
if err != nil { if err != nil {
return nil, err lastErr = err
xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", err)
continue
} }
if len(resp.Choices) != 1 { if len(resp.Choices) != 1 {
return nil, fmt.Errorf("no choices: %d", len(resp.Choices)) lastErr = fmt.Errorf("no choices: %d", len(resp.Choices))
xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", lastErr)
continue
} }
msg := resp.Choices[0].Message msg := resp.Choices[0].Message
if len(msg.ToolCalls) != 1 { if len(msg.ToolCalls) != 1 {
if err := a.saveConversation(append(conversation, msg), "decision"); err != nil {
xlog.Error("Error saving conversation", "error", err)
}
return &decisionResult{message: msg.Content}, nil return &decisionResult{message: msg.Content}, nil
} }
params := types.ActionParams{} params := types.ActionParams{}
if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil { if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil {
return nil, err lastErr = err
xlog.Warn("Attempt to parse action parameters failed", "attempt", attempts+1, "error", err)
continue
} }
if err := a.saveConversation(append(conversation, msg), "decision"); err != nil { if err := a.saveConversation(append(conversation, msg), "decision"); err != nil {
@@ -57,6 +68,9 @@ func (a *Agent) decision(
} }
return &decisionResult{actionParams: params, actioName: msg.ToolCalls[0].Function.Name, message: msg.Content}, nil return &decisionResult{actionParams: params, actioName: msg.ToolCalls[0].Function.Name, message: msg.Content}, nil
}
return nil, fmt.Errorf("failed to make a decision after %d attempts: %w", maxRetries, lastErr)
} }
type Messages []openai.ChatCompletionMessage type Messages []openai.ChatCompletionMessage
@@ -170,6 +184,7 @@ func (a *Agent) generateParameters(ctx context.Context, pickTemplate string, act
Type: openai.ToolTypeFunction, Type: openai.ToolTypeFunction,
Function: openai.ToolFunction{Name: act.Definition().Name.String()}, Function: openai.ToolFunction{Name: act.Definition().Name.String()},
}, },
maxAttempts,
) )
if attemptErr == nil && result.actionParams != nil { if attemptErr == nil && result.actionParams != nil {
return result, nil return result, nil
@@ -340,7 +355,7 @@ func (a *Agent) prepareHUD() (promptHUD *PromptHUD) {
} }
// pickAction picks an action based on the conversation // pickAction picks an action based on the conversation
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (types.Action, types.ActionParams, string, error) { func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage, maxRetries int) (types.Action, types.ActionParams, string, error) {
c := messages c := messages
if !a.options.forceReasoning { if !a.options.forceReasoning {
@@ -349,7 +364,8 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
thought, err := a.decision(ctx, thought, err := a.decision(ctx,
messages, messages,
a.availableActions().ToTools(), a.availableActions().ToTools(),
nil) nil,
maxRetries)
if err != nil { if err != nil {
return nil, nil, "", err return nil, nil, "", err
} }
@@ -390,7 +406,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
thought, err := a.decision(ctx, thought, err := a.decision(ctx,
c, c,
types.Actions{action.NewReasoning()}.ToTools(), types.Actions{action.NewReasoning()}.ToTools(),
action.NewReasoning().Definition().Name) action.NewReasoning().Definition().Name, maxRetries)
if err != nil { if err != nil {
return nil, nil, "", err return nil, nil, "", err
} }
@@ -421,7 +437,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
Content: "Given the assistant thought, pick the relevant action: " + reason, Content: "Given the assistant thought, pick the relevant action: " + reason,
}), }),
types.Actions{intentionsTools}.ToTools(), types.Actions{intentionsTools}.ToTools(),
intentionsTools.Definition().Name) intentionsTools.Definition().Name, maxRetries)
if err != nil { if err != nil {
return nil, nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err) return nil, nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err)
} }

View File

@@ -481,7 +481,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
job.ResetNextAction() job.ResetNextAction()
} else { } else {
var err error var err error
chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv) chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv, maxRetries)
if err != nil { if err != nil {
xlog.Error("Error picking action", "error", err) xlog.Error("Error picking action", "error", err)
job.Result.Finish(err) job.Result.Finish(err)
@@ -634,7 +634,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
// given the result, we can now ask OpenAI to complete the conversation or // given the result, we can now ask OpenAI to complete the conversation or
// to continue using another tool given the result // to continue using another tool given the result
followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv) followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv, maxRetries)
if err != nil { if err != nil {
job.Result.Conversation = conv job.Result.Conversation = conv
job.Result.Finish(fmt.Errorf("error picking action: %w", err)) job.Result.Finish(fmt.Errorf("error picking action: %w", err))