feat: add retries to pickAction
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -24,39 +24,53 @@ type decisionResult struct {
|
||||
func (a *Agent) decision(
|
||||
ctx context.Context,
|
||||
conversation []openai.ChatCompletionMessage,
|
||||
tools []openai.Tool, toolchoice any) (*decisionResult, error) {
|
||||
tools []openai.Tool, toolchoice any, maxRetries int) (*decisionResult, error) {
|
||||
|
||||
decision := openai.ChatCompletionRequest{
|
||||
Model: a.options.LLMAPI.Model,
|
||||
Messages: conversation,
|
||||
Tools: tools,
|
||||
ToolChoice: toolchoice,
|
||||
var lastErr error
|
||||
for attempts := 0; attempts < maxRetries; attempts++ {
|
||||
decision := openai.ChatCompletionRequest{
|
||||
Model: a.options.LLMAPI.Model,
|
||||
Messages: conversation,
|
||||
Tools: tools,
|
||||
ToolChoice: toolchoice,
|
||||
}
|
||||
|
||||
resp, err := a.client.CreateChatCompletion(ctx, decision)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(resp.Choices) != 1 {
|
||||
lastErr = fmt.Errorf("no choices: %d", len(resp.Choices))
|
||||
xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", lastErr)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := resp.Choices[0].Message
|
||||
if len(msg.ToolCalls) != 1 {
|
||||
if err := a.saveConversation(append(conversation, msg), "decision"); err != nil {
|
||||
xlog.Error("Error saving conversation", "error", err)
|
||||
}
|
||||
return &decisionResult{message: msg.Content}, nil
|
||||
}
|
||||
|
||||
params := types.ActionParams{}
|
||||
if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil {
|
||||
lastErr = err
|
||||
xlog.Warn("Attempt to parse action parameters failed", "attempt", attempts+1, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := a.saveConversation(append(conversation, msg), "decision"); err != nil {
|
||||
xlog.Error("Error saving conversation", "error", err)
|
||||
}
|
||||
|
||||
return &decisionResult{actionParams: params, actioName: msg.ToolCalls[0].Function.Name, message: msg.Content}, nil
|
||||
}
|
||||
|
||||
resp, err := a.client.CreateChatCompletion(ctx, decision)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(resp.Choices) != 1 {
|
||||
return nil, fmt.Errorf("no choices: %d", len(resp.Choices))
|
||||
}
|
||||
|
||||
msg := resp.Choices[0].Message
|
||||
if len(msg.ToolCalls) != 1 {
|
||||
return &decisionResult{message: msg.Content}, nil
|
||||
}
|
||||
|
||||
params := types.ActionParams{}
|
||||
if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := a.saveConversation(append(conversation, msg), "decision"); err != nil {
|
||||
xlog.Error("Error saving conversation", "error", err)
|
||||
}
|
||||
|
||||
return &decisionResult{actionParams: params, actioName: msg.ToolCalls[0].Function.Name, message: msg.Content}, nil
|
||||
return nil, fmt.Errorf("failed to make a decision after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
type Messages []openai.ChatCompletionMessage
|
||||
@@ -170,6 +184,7 @@ func (a *Agent) generateParameters(ctx context.Context, pickTemplate string, act
|
||||
Type: openai.ToolTypeFunction,
|
||||
Function: openai.ToolFunction{Name: act.Definition().Name.String()},
|
||||
},
|
||||
maxAttempts,
|
||||
)
|
||||
if attemptErr == nil && result.actionParams != nil {
|
||||
return result, nil
|
||||
@@ -340,7 +355,7 @@ func (a *Agent) prepareHUD() (promptHUD *PromptHUD) {
|
||||
}
|
||||
|
||||
// pickAction picks an action based on the conversation
|
||||
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (types.Action, types.ActionParams, string, error) {
|
||||
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage, maxRetries int) (types.Action, types.ActionParams, string, error) {
|
||||
c := messages
|
||||
|
||||
if !a.options.forceReasoning {
|
||||
@@ -349,7 +364,8 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
|
||||
thought, err := a.decision(ctx,
|
||||
messages,
|
||||
a.availableActions().ToTools(),
|
||||
nil)
|
||||
nil,
|
||||
maxRetries)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
@@ -390,7 +406,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
|
||||
thought, err := a.decision(ctx,
|
||||
c,
|
||||
types.Actions{action.NewReasoning()}.ToTools(),
|
||||
action.NewReasoning().Definition().Name)
|
||||
action.NewReasoning().Definition().Name, maxRetries)
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
@@ -421,7 +437,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
|
||||
Content: "Given the assistant thought, pick the relevant action: " + reason,
|
||||
}),
|
||||
types.Actions{intentionsTools}.ToTools(),
|
||||
intentionsTools.Definition().Name)
|
||||
intentionsTools.Definition().Name, maxRetries)
|
||||
if err != nil {
|
||||
return nil, nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err)
|
||||
}
|
||||
|
||||
@@ -481,7 +481,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
|
||||
job.ResetNextAction()
|
||||
} else {
|
||||
var err error
|
||||
chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv)
|
||||
chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv, maxRetries)
|
||||
if err != nil {
|
||||
xlog.Error("Error picking action", "error", err)
|
||||
job.Result.Finish(err)
|
||||
@@ -634,7 +634,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
|
||||
|
||||
// given the result, we can now ask OpenAI to complete the conversation or
|
||||
// to continue using another tool given the result
|
||||
followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv)
|
||||
followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv, maxRetries)
|
||||
if err != nil {
|
||||
job.Result.Conversation = conv
|
||||
job.Result.Finish(fmt.Errorf("error picking action: %w", err))
|
||||
|
||||
Reference in New Issue
Block a user