fix state update, save/load

This commit is contained in:
mudler
2024-04-04 16:58:25 +02:00
parent 9173156e40
commit b4fd482f66
7 changed files with 151 additions and 23 deletions

View File

@@ -117,13 +117,12 @@ const hudTemplate = `You have a character and your replies and actions might be
{{end}}
This is your current state:
{{if .CurrentState.NowDoing}}NowDoing: {{.CurrentState.NowDoing}} {{end}}
{{if .CurrentState.DoingNext}}DoingNext: {{.CurrentState.DoingNext}} {{end}}
{{if .PermanentGoal}}Your permanent goal is: {{.PermanentGoal}} {{end}}
{{if .CurrentState.Goal}}Your current goal is: {{.CurrentState.Goal}} {{end}}
NowDoing: {{if .CurrentState.NowDoing}}{{.CurrentState.NowDoing}}{{else}}Nothing{{end}}
DoingNext: {{if .CurrentState.DoingNext}}{{.CurrentState.DoingNext}}{{else}}Nothing{{end}}
Your permanent goal is: {{if .PermanentGoal}}{{.PermanentGoal}}{{else}}Nothing{{end}}
Your current goal is: {{if .CurrentState.Goal}}{{.CurrentState.Goal}}{{else}}Nothing{{end}}
You have done: {{range .CurrentState.DoneHistory}}{{.}} {{end}}
You have a short memory with: {{range .CurrentState.Memories}}{{.}} {{end}}
`
You have a short memory with: {{range .CurrentState.Memories}}{{.}} {{end}}`
// pickAction picks an action based on the conversation
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (Action, string, error) {
@@ -160,9 +159,11 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
if err != nil {
return nil, "", err
}
//fmt.Println("=== HUD START ===", hud.String(), "=== HUD END ===")
//fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===")
if a.options.debugMode {
fmt.Println("=== HUD START ===", hud.String(), "=== HUD END ===")
fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===")
}
// Get all the available actions IDs
actionsID := []string{}
@@ -214,7 +215,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
Actions{intentionsTools}.ToTools(),
intentionsTools.Definition().Name)
if err != nil {
return nil, "", err
return nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err)
}
actionChoice := action.IntentResponse{}

View File

@@ -3,6 +3,7 @@ package agent
import (
"context"
"fmt"
"os"
"strings"
"sync"
"time"
@@ -92,6 +93,35 @@ func New(opts ...Option) (*Agent, error) {
}
}
if a.options.statefile != "" {
if _, err := os.Stat(a.options.statefile); err == nil {
if err = a.LoadState(a.options.statefile); err != nil {
return a, fmt.Errorf("failed to load state: %v", err)
}
}
}
if a.options.characterfile != "" {
if _, err := os.Stat(a.options.characterfile); err == nil {
// if there is a file, load the character back
if err = a.LoadCharacter(a.options.characterfile); err != nil {
return a, fmt.Errorf("failed to load character: %v", err)
}
} else {
// otherwise save it for next time
if err = a.SaveCharacter(a.options.characterfile); err != nil {
return a, fmt.Errorf("failed to save character: %v", err)
}
}
}
if a.options.debugMode {
fmt.Println("=== Agent in Debug mode ===")
fmt.Println(a.Character.String())
fmt.Println(a.State().String())
fmt.Println("Permanent goal: ", a.options.permanentGoal)
}
return a, nil
}
@@ -144,16 +174,28 @@ func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (
}
}
if a.options.debugMode {
fmt.Println("Action", chosenAction.Definition().Name)
fmt.Println("Result", result)
}
if chosenAction.Definition().Name.Is(action.StateActionName) {
// We need to store the result in the state
state := action.StateResult{}
err = decisionResult.actionParams.Unmarshal(&result)
err = decisionResult.actionParams.Unmarshal(&state)
if err != nil {
return "", err
return "", fmt.Errorf("error unmarshalling state of the agent: %w", err)
}
// update the current state with the one we just got from the action
a.currentState = &state
// update the state file
if a.options.statefile != "" {
if err := a.SaveState(a.options.statefile); err != nil {
return "", err
}
}
}
return result, nil
@@ -207,7 +249,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
params, err := a.generateParameters(ctx, chosenAction, a.currentConversation)
if err != nil {
job.Result.Finish(err)
job.Result.Finish(fmt.Errorf("error generating action's parameters: %w", err))
return
}
@@ -222,6 +264,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
return
}
// If we don't have to reply , run the action!
if !chosenAction.Definition().Name.Is(action.ReplyActionName) {
result, err := a.runAction(chosenAction, params)
if err != nil {

View File

@@ -125,5 +125,27 @@ var _ = Describe("Agent test", func() {
}
Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
})
It("updates the state with internal actions", func() {
agent, err := New(
WithLLMAPIURL(apiModel),
WithModel(testModel),
EnableHUD,
DebugMode,
// EnableStandaloneJob,
WithRandomIdentity(),
WithPermanentGoal("I want to learn to play music"),
)
Expect(err).ToNot(HaveOccurred())
go agent.Run()
defer agent.Stop()
result := agent.Ask(
WithText("Update your goals such as you want to learn to play the guitar"),
)
fmt.Printf("%+v\n", result)
Expect(result.Error).ToNot(HaveOccurred())
Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
})
})
})

View File

@@ -19,6 +19,9 @@ type options struct {
randomIdentity bool
userActions Actions
enableHUD, standaloneJob bool
debugMode bool
characterfile string
statefile string
context context.Context
permanentGoal string
}
@@ -54,6 +57,11 @@ var EnableHUD = func(o *options) error {
return nil
}
var DebugMode = func(o *options) error {
o.debugMode = true
return nil
}
// EnableStandaloneJob is an option to enable the agent
// to run jobs in the background automatically
var EnableStandaloneJob = func(o *options) error {

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/mudler/local-agent-framework/action"
"github.com/mudler/local-agent-framework/llm"
@@ -40,8 +41,39 @@ func Load(path string) (*Character, error) {
return &c, nil
}
func (a *Agent) Save(path string) error {
data, err := json.Marshal(a.options.character)
func (a *Agent) State() action.StateResult {
return *a.currentState
}
func (a *Agent) LoadState(path string) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
return json.Unmarshal(data, a.currentState)
}
func (a *Agent) LoadCharacter(path string) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
return json.Unmarshal(data, &a.Character)
}
func (a *Agent) SaveState(path string) error {
os.MkdirAll(filepath.Dir(path), 0755)
data, err := json.Marshal(a.currentState)
if err != nil {
return err
}
os.WriteFile(path, data, 0644)
return nil
}
func (a *Agent) SaveCharacter(path string) error {
os.MkdirAll(filepath.Dir(path), 0755)
data, err := json.Marshal(a.Character)
if err != nil {
return err
}
@@ -59,7 +91,7 @@ func (a *Agent) generateIdentity(guidance string) error {
}
if !a.validCharacter() {
return fmt.Errorf("generated character is not valid ( guidance: %s ): %v", guidance, a.String())
return fmt.Errorf("generated character is not valid ( guidance: %s ): %v", guidance, a.Character.String())
}
return nil
}
@@ -80,13 +112,13 @@ Hobbies: %v
Music taste: %v
=====================`
func (a *Agent) String() string {
func (c *Character) String() string {
return fmt.Sprintf(
fmtT,
a.Character.Name,
a.Character.Age,
a.Character.Occupation,
a.Character.Hobbies,
a.Character.MusicTaste,
c.Name,
c.Age,
c.Occupation,
c.Hobbies,
c.MusicTaste,
)
}