feat: add retries to pickAction
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -24,39 +24,53 @@ type decisionResult struct {
|
|||||||
func (a *Agent) decision(
|
func (a *Agent) decision(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
conversation []openai.ChatCompletionMessage,
|
conversation []openai.ChatCompletionMessage,
|
||||||
tools []openai.Tool, toolchoice any) (*decisionResult, error) {
|
tools []openai.Tool, toolchoice any, maxRetries int) (*decisionResult, error) {
|
||||||
|
|
||||||
decision := openai.ChatCompletionRequest{
|
var lastErr error
|
||||||
Model: a.options.LLMAPI.Model,
|
for attempts := 0; attempts < maxRetries; attempts++ {
|
||||||
Messages: conversation,
|
decision := openai.ChatCompletionRequest{
|
||||||
Tools: tools,
|
Model: a.options.LLMAPI.Model,
|
||||||
ToolChoice: toolchoice,
|
Messages: conversation,
|
||||||
|
Tools: tools,
|
||||||
|
ToolChoice: toolchoice,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := a.client.CreateChatCompletion(ctx, decision)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.Choices) != 1 {
|
||||||
|
lastErr = fmt.Errorf("no choices: %d", len(resp.Choices))
|
||||||
|
xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", lastErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := resp.Choices[0].Message
|
||||||
|
if len(msg.ToolCalls) != 1 {
|
||||||
|
if err := a.saveConversation(append(conversation, msg), "decision"); err != nil {
|
||||||
|
xlog.Error("Error saving conversation", "error", err)
|
||||||
|
}
|
||||||
|
return &decisionResult{message: msg.Content}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
params := types.ActionParams{}
|
||||||
|
if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
xlog.Warn("Attempt to parse action parameters failed", "attempt", attempts+1, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := a.saveConversation(append(conversation, msg), "decision"); err != nil {
|
||||||
|
xlog.Error("Error saving conversation", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &decisionResult{actionParams: params, actioName: msg.ToolCalls[0].Function.Name, message: msg.Content}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := a.client.CreateChatCompletion(ctx, decision)
|
return nil, fmt.Errorf("failed to make a decision after %d attempts: %w", maxRetries, lastErr)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp.Choices) != 1 {
|
|
||||||
return nil, fmt.Errorf("no choices: %d", len(resp.Choices))
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := resp.Choices[0].Message
|
|
||||||
if len(msg.ToolCalls) != 1 {
|
|
||||||
return &decisionResult{message: msg.Content}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
params := types.ActionParams{}
|
|
||||||
if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := a.saveConversation(append(conversation, msg), "decision"); err != nil {
|
|
||||||
xlog.Error("Error saving conversation", "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &decisionResult{actionParams: params, actioName: msg.ToolCalls[0].Function.Name, message: msg.Content}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Messages []openai.ChatCompletionMessage
|
type Messages []openai.ChatCompletionMessage
|
||||||
@@ -170,6 +184,7 @@ func (a *Agent) generateParameters(ctx context.Context, pickTemplate string, act
|
|||||||
Type: openai.ToolTypeFunction,
|
Type: openai.ToolTypeFunction,
|
||||||
Function: openai.ToolFunction{Name: act.Definition().Name.String()},
|
Function: openai.ToolFunction{Name: act.Definition().Name.String()},
|
||||||
},
|
},
|
||||||
|
maxAttempts,
|
||||||
)
|
)
|
||||||
if attemptErr == nil && result.actionParams != nil {
|
if attemptErr == nil && result.actionParams != nil {
|
||||||
return result, nil
|
return result, nil
|
||||||
@@ -340,7 +355,7 @@ func (a *Agent) prepareHUD() (promptHUD *PromptHUD) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// pickAction picks an action based on the conversation
|
// pickAction picks an action based on the conversation
|
||||||
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (types.Action, types.ActionParams, string, error) {
|
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage, maxRetries int) (types.Action, types.ActionParams, string, error) {
|
||||||
c := messages
|
c := messages
|
||||||
|
|
||||||
if !a.options.forceReasoning {
|
if !a.options.forceReasoning {
|
||||||
@@ -349,7 +364,8 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
|
|||||||
thought, err := a.decision(ctx,
|
thought, err := a.decision(ctx,
|
||||||
messages,
|
messages,
|
||||||
a.availableActions().ToTools(),
|
a.availableActions().ToTools(),
|
||||||
nil)
|
nil,
|
||||||
|
maxRetries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, "", err
|
return nil, nil, "", err
|
||||||
}
|
}
|
||||||
@@ -390,7 +406,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
|
|||||||
thought, err := a.decision(ctx,
|
thought, err := a.decision(ctx,
|
||||||
c,
|
c,
|
||||||
types.Actions{action.NewReasoning()}.ToTools(),
|
types.Actions{action.NewReasoning()}.ToTools(),
|
||||||
action.NewReasoning().Definition().Name)
|
action.NewReasoning().Definition().Name, maxRetries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, "", err
|
return nil, nil, "", err
|
||||||
}
|
}
|
||||||
@@ -421,7 +437,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
|
|||||||
Content: "Given the assistant thought, pick the relevant action: " + reason,
|
Content: "Given the assistant thought, pick the relevant action: " + reason,
|
||||||
}),
|
}),
|
||||||
types.Actions{intentionsTools}.ToTools(),
|
types.Actions{intentionsTools}.ToTools(),
|
||||||
intentionsTools.Definition().Name)
|
intentionsTools.Definition().Name, maxRetries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err)
|
return nil, nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -481,7 +481,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
|
|||||||
job.ResetNextAction()
|
job.ResetNextAction()
|
||||||
} else {
|
} else {
|
||||||
var err error
|
var err error
|
||||||
chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv)
|
chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv, maxRetries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("Error picking action", "error", err)
|
xlog.Error("Error picking action", "error", err)
|
||||||
job.Result.Finish(err)
|
job.Result.Finish(err)
|
||||||
@@ -634,7 +634,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
|
|||||||
|
|
||||||
// given the result, we can now ask OpenAI to complete the conversation or
|
// given the result, we can now ask OpenAI to complete the conversation or
|
||||||
// to continue using another tool given the result
|
// to continue using another tool given the result
|
||||||
followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv)
|
followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv, maxRetries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
job.Result.Conversation = conv
|
job.Result.Conversation = conv
|
||||||
job.Result.Finish(fmt.Errorf("error picking action: %w", err))
|
job.Result.Finish(fmt.Errorf("error picking action: %w", err))
|
||||||
|
|||||||
Reference in New Issue
Block a user