feat: track plan action when is being executed, also tests (#72)
* feat: track plan action when is being executed, also tests * Update core/agent/agent_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update core/agent/actions.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e5e238efc0
commit
abb3ffc109
@@ -225,6 +225,12 @@ func (a *Agent) handlePlanning(ctx context.Context, job *Job, chosenAction Actio
|
||||
return fmt.Errorf("error unmarshalling plan result: %w", err)
|
||||
}
|
||||
|
||||
stateResult := ActionState{ActionCurrentState{chosenAction, actionParams, reasoning}, action.ActionResult{
|
||||
Result: fmt.Sprintf("planning %s, subtasks: %+v", planResult.Goal, planResult.Subtasks),
|
||||
}}
|
||||
job.Result.SetResult(stateResult)
|
||||
job.CallbackWithResult(stateResult)
|
||||
|
||||
xlog.Info("[Planning] starts", "agent", a.Character.Name, "goal", planResult.Goal)
|
||||
for _, s := range planResult.Subtasks {
|
||||
xlog.Info("[Planning] subtask", "agent", a.Character.Name, "action", s.Action, "reasoning", s.Reasoning)
|
||||
@@ -242,25 +248,33 @@ func (a *Agent) handlePlanning(ctx context.Context, job *Job, chosenAction Actio
|
||||
"reasoning", reasoning,
|
||||
)
|
||||
|
||||
action := a.availableActions().Find(subtask.Action)
|
||||
subTaskAction := a.availableActions().Find(subtask.Action)
|
||||
subTaskReasoning := fmt.Sprintf("%s, overall goal is: %s", subtask.Reasoning, planResult.Goal)
|
||||
|
||||
params, err := a.generateParameters(ctx, pickTemplate, action, a.currentConversation, fmt.Sprintf("%s, overall goal is: %s", subtask.Reasoning, planResult.Goal))
|
||||
params, err := a.generateParameters(ctx, pickTemplate, subTaskAction, a.currentConversation, subTaskReasoning)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating action's parameters: %w", err)
|
||||
|
||||
}
|
||||
actionParams = params.actionParams
|
||||
|
||||
result, err := a.runAction(action, actionParams)
|
||||
if !job.Callback(ActionCurrentState{subTaskAction, actionParams, subTaskReasoning}) {
|
||||
job.Result.SetResult(ActionState{ActionCurrentState{chosenAction, actionParams, subTaskReasoning}, action.ActionResult{Result: "stopped by callback"}})
|
||||
job.Result.Conversation = a.currentConversation
|
||||
job.Result.Finish(nil)
|
||||
break
|
||||
}
|
||||
|
||||
result, err := a.runAction(subTaskAction, actionParams)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error running action: %w", err)
|
||||
}
|
||||
|
||||
stateResult := ActionState{ActionCurrentState{action, actionParams, subtask.Reasoning}, result}
|
||||
stateResult := ActionState{ActionCurrentState{subTaskAction, actionParams, subTaskReasoning}, result}
|
||||
job.Result.SetResult(stateResult)
|
||||
job.CallbackWithResult(stateResult)
|
||||
xlog.Debug("[subtask] Action executed", "agent", a.Character.Name, "action", action.Definition().Name, "result", result)
|
||||
a.addFunctionResultToConversation(action, actionParams, result)
|
||||
xlog.Debug("[subtask] Action executed", "agent", a.Character.Name, "action", subTaskAction.Definition().Name, "result", result)
|
||||
a.addFunctionResultToConversation(subTaskAction, actionParams, result)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -612,21 +612,24 @@ func (a *Agent) consumeJob(job *Job, role string) {
|
||||
|
||||
// If we don't have to reply , run the action!
|
||||
if !chosenAction.Definition().Name.Is(action.ReplyActionName) {
|
||||
result, err := a.runAction(chosenAction, actionParams)
|
||||
if err != nil {
|
||||
//job.Result.Finish(fmt.Errorf("error running action: %w", err))
|
||||
//return
|
||||
// make the LLM aware of the error of running the action instead of stopping the job here
|
||||
result.Result = fmt.Sprintf("Error running tool: %v", err)
|
||||
|
||||
if !chosenAction.Definition().Name.Is(action.PlanActionName) {
|
||||
result, err := a.runAction(chosenAction, actionParams)
|
||||
if err != nil {
|
||||
//job.Result.Finish(fmt.Errorf("error running action: %w", err))
|
||||
//return
|
||||
// make the LLM aware of the error of running the action instead of stopping the job here
|
||||
result.Result = fmt.Sprintf("Error running tool: %v", err)
|
||||
}
|
||||
|
||||
stateResult := ActionState{ActionCurrentState{chosenAction, actionParams, reasoning}, result}
|
||||
job.Result.SetResult(stateResult)
|
||||
job.CallbackWithResult(stateResult)
|
||||
xlog.Debug("Action executed", "agent", a.Character.Name, "action", chosenAction.Definition().Name, "result", result)
|
||||
|
||||
a.addFunctionResultToConversation(chosenAction, actionParams, result)
|
||||
}
|
||||
|
||||
stateResult := ActionState{ActionCurrentState{chosenAction, actionParams, reasoning}, result}
|
||||
job.Result.SetResult(stateResult)
|
||||
job.CallbackWithResult(stateResult)
|
||||
xlog.Debug("Action executed", "agent", a.Character.Name, "action", chosenAction.Definition().Name, "result", result)
|
||||
|
||||
a.addFunctionResultToConversation(chosenAction, actionParams, result)
|
||||
|
||||
//a.currentConversation = append(a.currentConversation, messages...)
|
||||
//a.currentConversation = messages
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ func TestAgent(t *testing.T) {
|
||||
|
||||
var testModel = os.Getenv("LOCALAGENT_MODEL")
|
||||
var apiURL = os.Getenv("LOCALAI_API_URL")
|
||||
var apiKeyURL = os.Getenv("LOCALAI_API_KEY")
|
||||
|
||||
func init() {
|
||||
if testModel == "" {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAgent/pkg/xlog"
|
||||
"github.com/mudler/LocalAgent/services/actions"
|
||||
|
||||
"github.com/mudler/LocalAgent/core/action"
|
||||
. "github.com/mudler/LocalAgent/core/agent"
|
||||
@@ -202,6 +203,38 @@ var _ = Describe("Agent test", func() {
|
||||
Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
|
||||
})
|
||||
|
||||
It("Can generate a plan", func() {
|
||||
agent, err := New(
|
||||
WithLLMAPIURL(apiURL),
|
||||
WithModel(testModel),
|
||||
WithLLMAPIKey(apiKeyURL),
|
||||
WithActions(
|
||||
actions.NewSearch(map[string]string{}),
|
||||
),
|
||||
EnablePlanning,
|
||||
EnableForceReasoning,
|
||||
// EnableStandaloneJob,
|
||||
// WithRandomIdentity(),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go agent.Run()
|
||||
defer agent.Stop()
|
||||
|
||||
result := agent.Ask(
|
||||
WithText("plan a trip to San Francisco from Venice, Italy"),
|
||||
)
|
||||
Expect(len(result.State)).To(BeNumerically(">", 1))
|
||||
|
||||
actionsExecuted := []string{}
|
||||
for _, r := range result.State {
|
||||
xlog.Info(r.Result)
|
||||
actionsExecuted = append(actionsExecuted, r.Action.Definition().Name.String())
|
||||
}
|
||||
Expect(actionsExecuted).To(ContainElement("search_internet"), fmt.Sprint(result))
|
||||
Expect(actionsExecuted).To(ContainElement("plan"), fmt.Sprint(result))
|
||||
|
||||
})
|
||||
|
||||
/*
|
||||
It("it automatically performs things in the background", func() {
|
||||
agent, err := New(
|
||||
|
||||
Reference in New Issue
Block a user