Do not do two requests if reasoning is disabled

Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
mudler
2024-05-23 18:58:53 +02:00
parent 3f9b454276
commit 989a2421ba
2 changed files with 47 additions and 37 deletions

View File

@@ -33,6 +33,7 @@ type Agent struct {
currentReasoning string
currentState *action.StateResult
nextAction Action
nextActionParams *action.ActionParams
currentConversation Messages
selfEvaluationInProgress bool
pause bool
@@ -182,10 +183,10 @@ func (a *Agent) Paused() bool {
return a.pause
}
func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (result string, err error) {
func (a *Agent) runAction(chosenAction Action, params action.ActionParams) (result string, err error) {
for _, action := range a.systemInternalActions() {
if action.Definition().Name == chosenAction.Definition().Name {
if result, err = action.Run(a.actionContext, decisionResult.actionParams); err != nil {
if result, err = action.Run(a.actionContext, params); err != nil {
return "", fmt.Errorf("error running action: %w", err)
}
}
@@ -197,7 +198,7 @@ func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (
// We need to store the result in the state
state := action.StateResult{}
err = decisionResult.actionParams.Unmarshal(&state)
err = params.Unmarshal(&state)
if err != nil {
return "", fmt.Errorf("error unmarshalling state of the agent: %w", err)
}
@@ -350,17 +351,20 @@ func (a *Agent) consumeJob(job *Job, role string) {
// choose an action first
var chosenAction Action
var reasoning string
var actionParams action.ActionParams
if a.currentReasoning != "" && a.nextAction != nil {
if a.nextAction != nil {
// if we are being re-evaluated, we already have the action
// and the reasoning. Consume it here and reset it
chosenAction = a.nextAction
reasoning = a.currentReasoning
actionParams = *a.nextActionParams
a.currentReasoning = ""
a.nextActionParams = nil
a.nextAction = nil
} else {
var err error
chosenAction, reasoning, err = a.pickAction(ctx, pickTemplate, a.currentConversation)
chosenAction, actionParams, reasoning, err = a.pickAction(ctx, pickTemplate, a.currentConversation)
if err != nil {
xlog.Error("Error picking action", "error", err)
job.Result.Finish(err)
@@ -391,16 +395,20 @@ func (a *Agent) consumeJob(job *Job, role string) {
return
}
xlog.Info("Generating parameters",
"agent", a.Character.Name,
"action", chosenAction.Definition().Name,
"reasoning", reasoning,
)
// if we force a reasoning, we need to generate the parameters
if a.options.forceReasoning || actionParams == nil {
xlog.Info("Generating parameters",
"agent", a.Character.Name,
"action", chosenAction.Definition().Name,
"reasoning", reasoning,
)
params, err := a.generateParameters(ctx, pickTemplate, chosenAction, a.currentConversation, reasoning)
if err != nil {
job.Result.Finish(fmt.Errorf("error generating action's parameters: %w", err))
return
params, err := a.generateParameters(ctx, pickTemplate, chosenAction, a.currentConversation, reasoning)
if err != nil {
job.Result.Finish(fmt.Errorf("error generating action's parameters: %w", err))
return
}
actionParams = params.actionParams
}
xlog.Info(
@@ -408,16 +416,16 @@ func (a *Agent) consumeJob(job *Job, role string) {
"agent", a.Character.Name,
"action", chosenAction.Definition().Name,
"reasoning", reasoning,
"params", params.actionParams.String(),
"params", actionParams.String(),
)
if params.actionParams == nil {
if actionParams == nil {
job.Result.Finish(fmt.Errorf("no parameters"))
return
}
if !job.Callback(ActionCurrentState{chosenAction, params.actionParams, reasoning}) {
job.Result.SetResult(ActionState{ActionCurrentState{chosenAction, params.actionParams, reasoning}, "stopped by callback"})
if !job.Callback(ActionCurrentState{chosenAction, actionParams, reasoning}) {
job.Result.SetResult(ActionState{ActionCurrentState{chosenAction, actionParams, reasoning}, "stopped by callback"})
job.Result.Finish(nil)
return
}
@@ -426,7 +434,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
chosenAction.Definition().Name.Is(action.ConversationActionName) {
message := action.ConversationActionResponse{}
if err := params.actionParams.Unmarshal(&message); err != nil {
if err := actionParams.Unmarshal(&message); err != nil {
job.Result.Finish(fmt.Errorf("error unmarshalling conversation response: %w", err))
return
}
@@ -450,7 +458,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
// If we don't have to reply , run the action!
if !chosenAction.Definition().Name.Is(action.ReplyActionName) {
result, err := a.runAction(chosenAction, params)
result, err := a.runAction(chosenAction, actionParams)
if err != nil {
//job.Result.Finish(fmt.Errorf("error running action: %w", err))
//return
@@ -458,7 +466,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
result = fmt.Sprintf("Error running tool: %v", err)
}
stateResult := ActionState{ActionCurrentState{chosenAction, params.actionParams, reasoning}, result}
stateResult := ActionState{ActionCurrentState{chosenAction, actionParams, reasoning}, result}
job.Result.SetResult(stateResult)
job.CallbackWithResult(stateResult)
@@ -467,7 +475,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
Role: "assistant",
FunctionCall: &openai.FunctionCall{
Name: chosenAction.Definition().Name.String(),
Arguments: params.actionParams.String(),
Arguments: actionParams.String(),
},
})
@@ -484,7 +492,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
// given the result, we can now ask OpenAI to complete the conversation or
// to continue using another tool given the result
followingAction, reasoning, err := a.pickAction(ctx, reEvaluationTemplate, a.currentConversation)
followingAction, followingParams, reasoning, err := a.pickAction(ctx, reEvaluationTemplate, a.currentConversation)
if err != nil {
job.Result.Finish(fmt.Errorf("error picking action: %w", err))
return
@@ -498,6 +506,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
// call ourselves again
a.currentReasoning = reasoning
a.nextAction = followingAction
a.nextActionParams = &followingParams
job.Text = ""
a.consumeJob(job, role)
return
@@ -521,7 +530,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
// decode the response
replyResponse := action.ReplyResponse{}
if err := params.actionParams.Unmarshal(&replyResponse); err != nil {
if err := actionParams.Unmarshal(&replyResponse); err != nil {
job.Result.Finish(fmt.Errorf("error unmarshalling reply response: %w", err))
return
}