fix state update, save/load

This commit is contained in:
mudler
2024-04-04 16:58:25 +02:00
parent 9173156e40
commit b4fd482f66
7 changed files with 151 additions and 23 deletions

View File

@@ -24,7 +24,7 @@ func NewContext(ctx context.Context, cancel context.CancelFunc) *ActionContext {
} }
} }
type ActionParams map[string]string type ActionParams map[string]interface{}
func (ap ActionParams) Read(s string) error { func (ap ActionParams) Read(s string) error {
err := json.Unmarshal([]byte(s), &ap) err := json.Unmarshal([]byte(s), &ap)

View File

@@ -1,6 +1,8 @@
package action package action
import ( import (
"fmt"
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
) )
@@ -31,7 +33,7 @@ type StateResult struct {
} }
func (a *StateAction) Run(ActionParams) (string, error) { func (a *StateAction) Run(ActionParams) (string, error) {
return "no-op", nil return "internal state has been updated", nil
} }
func (a *StateAction) Definition() ActionDefinition { func (a *StateAction) Definition() ActionDefinition {
@@ -68,3 +70,23 @@ func (a *StateAction) Definition() ActionDefinition {
}, },
} }
} }
const fmtT = `=====================
NowDoing: %s
DoingNext: %s
Your current goal is: %s
You have done: %+v
You have a short memory with: %+v
=====================
`
func (c StateResult) String() string {
return fmt.Sprintf(
fmtT,
c.NowDoing,
c.DoingNext,
c.Goal,
c.DoneHistory,
c.Memories,
)
}

View File

@@ -117,13 +117,12 @@ const hudTemplate = `You have a character and your replies and actions might be
{{end}} {{end}}
This is your current state: This is your current state:
{{if .CurrentState.NowDoing}}NowDoing: {{.CurrentState.NowDoing}} {{end}} NowDoing: {{if .CurrentState.NowDoing}}{{.CurrentState.NowDoing}}{{else}}Nothing{{end}}
{{if .CurrentState.DoingNext}}DoingNext: {{.CurrentState.DoingNext}} {{end}} DoingNext: {{if .CurrentState.DoingNext}}{{.CurrentState.DoingNext}}{{else}}Nothing{{end}}
{{if .PermanentGoal}}Your permanent goal is: {{.PermanentGoal}} {{end}} Your permanent goal is: {{if .PermanentGoal}}{{.PermanentGoal}}{{else}}Nothing{{end}}
{{if .CurrentState.Goal}}Your current goal is: {{.CurrentState.Goal}} {{end}} Your current goal is: {{if .CurrentState.Goal}}{{.CurrentState.Goal}}{{else}}Nothing{{end}}
You have done: {{range .CurrentState.DoneHistory}}{{.}} {{end}} You have done: {{range .CurrentState.DoneHistory}}{{.}} {{end}}
You have a short memory with: {{range .CurrentState.Memories}}{{.}} {{end}} You have a short memory with: {{range .CurrentState.Memories}}{{.}} {{end}}`
`
// pickAction picks an action based on the conversation // pickAction picks an action based on the conversation
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (Action, string, error) { func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (Action, string, error) {
@@ -160,9 +159,11 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
//fmt.Println("=== HUD START ===", hud.String(), "=== HUD END ===")
//fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===") if a.options.debugMode {
fmt.Println("=== HUD START ===", hud.String(), "=== HUD END ===")
fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===")
}
// Get all the available actions IDs // Get all the available actions IDs
actionsID := []string{} actionsID := []string{}
@@ -214,7 +215,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
Actions{intentionsTools}.ToTools(), Actions{intentionsTools}.ToTools(),
intentionsTools.Definition().Name) intentionsTools.Definition().Name)
if err != nil { if err != nil {
return nil, "", err return nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err)
} }
actionChoice := action.IntentResponse{} actionChoice := action.IntentResponse{}

View File

@@ -3,6 +3,7 @@ package agent
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -92,6 +93,35 @@ func New(opts ...Option) (*Agent, error) {
} }
} }
if a.options.statefile != "" {
if _, err := os.Stat(a.options.statefile); err == nil {
if err = a.LoadState(a.options.statefile); err != nil {
return a, fmt.Errorf("failed to load state: %v", err)
}
}
}
if a.options.characterfile != "" {
if _, err := os.Stat(a.options.characterfile); err == nil {
// if there is a file, load the character back
if err = a.LoadCharacter(a.options.characterfile); err != nil {
return a, fmt.Errorf("failed to load character: %v", err)
}
} else {
// otherwise save it for next time
if err = a.SaveCharacter(a.options.characterfile); err != nil {
return a, fmt.Errorf("failed to save character: %v", err)
}
}
}
if a.options.debugMode {
fmt.Println("=== Agent in Debug mode ===")
fmt.Println(a.Character.String())
fmt.Println(a.State().String())
fmt.Println("Permanent goal: ", a.options.permanentGoal)
}
return a, nil return a, nil
} }
@@ -144,16 +174,28 @@ func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (
} }
} }
if a.options.debugMode {
fmt.Println("Action", chosenAction.Definition().Name)
fmt.Println("Result", result)
}
if chosenAction.Definition().Name.Is(action.StateActionName) { if chosenAction.Definition().Name.Is(action.StateActionName) {
// We need to store the result in the state // We need to store the result in the state
state := action.StateResult{} state := action.StateResult{}
err = decisionResult.actionParams.Unmarshal(&result) err = decisionResult.actionParams.Unmarshal(&state)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("error unmarshalling state of the agent: %w", err)
} }
// update the current state with the one we just got from the action // update the current state with the one we just got from the action
a.currentState = &state a.currentState = &state
// update the state file
if a.options.statefile != "" {
if err := a.SaveState(a.options.statefile); err != nil {
return "", err
}
}
} }
return result, nil return result, nil
@@ -207,7 +249,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
params, err := a.generateParameters(ctx, chosenAction, a.currentConversation) params, err := a.generateParameters(ctx, chosenAction, a.currentConversation)
if err != nil { if err != nil {
job.Result.Finish(err) job.Result.Finish(fmt.Errorf("error generating action's parameters: %w", err))
return return
} }
@@ -222,6 +264,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
return return
} }
// If we don't have to reply , run the action!
if !chosenAction.Definition().Name.Is(action.ReplyActionName) { if !chosenAction.Definition().Name.Is(action.ReplyActionName) {
result, err := a.runAction(chosenAction, params) result, err := a.runAction(chosenAction, params)
if err != nil { if err != nil {

View File

@@ -125,5 +125,27 @@ var _ = Describe("Agent test", func() {
} }
Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res)) Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
}) })
It("updates the state with internal actions", func() {
agent, err := New(
WithLLMAPIURL(apiModel),
WithModel(testModel),
EnableHUD,
DebugMode,
// EnableStandaloneJob,
WithRandomIdentity(),
WithPermanentGoal("I want to learn to play music"),
)
Expect(err).ToNot(HaveOccurred())
go agent.Run()
defer agent.Stop()
result := agent.Ask(
WithText("Update your goals such as you want to learn to play the guitar"),
)
fmt.Printf("%+v\n", result)
Expect(result.Error).ToNot(HaveOccurred())
Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
})
}) })
}) })

View File

@@ -19,6 +19,9 @@ type options struct {
randomIdentity bool randomIdentity bool
userActions Actions userActions Actions
enableHUD, standaloneJob bool enableHUD, standaloneJob bool
debugMode bool
characterfile string
statefile string
context context.Context context context.Context
permanentGoal string permanentGoal string
} }
@@ -54,6 +57,11 @@ var EnableHUD = func(o *options) error {
return nil return nil
} }
var DebugMode = func(o *options) error {
o.debugMode = true
return nil
}
// EnableStandaloneJob is an option to enable the agent // EnableStandaloneJob is an option to enable the agent
// to run jobs in the background automatically // to run jobs in the background automatically
var EnableStandaloneJob = func(o *options) error { var EnableStandaloneJob = func(o *options) error {

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"path/filepath"
"github.com/mudler/local-agent-framework/action" "github.com/mudler/local-agent-framework/action"
"github.com/mudler/local-agent-framework/llm" "github.com/mudler/local-agent-framework/llm"
@@ -40,8 +41,39 @@ func Load(path string) (*Character, error) {
return &c, nil return &c, nil
} }
func (a *Agent) Save(path string) error { func (a *Agent) State() action.StateResult {
data, err := json.Marshal(a.options.character) return *a.currentState
}
func (a *Agent) LoadState(path string) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
return json.Unmarshal(data, a.currentState)
}
func (a *Agent) LoadCharacter(path string) error {
data, err := os.ReadFile(path)
if err != nil {
return err
}
return json.Unmarshal(data, &a.Character)
}
func (a *Agent) SaveState(path string) error {
os.MkdirAll(filepath.Dir(path), 0755)
data, err := json.Marshal(a.currentState)
if err != nil {
return err
}
os.WriteFile(path, data, 0644)
return nil
}
func (a *Agent) SaveCharacter(path string) error {
os.MkdirAll(filepath.Dir(path), 0755)
data, err := json.Marshal(a.Character)
if err != nil { if err != nil {
return err return err
} }
@@ -59,7 +91,7 @@ func (a *Agent) generateIdentity(guidance string) error {
} }
if !a.validCharacter() { if !a.validCharacter() {
return fmt.Errorf("generated character is not valid ( guidance: %s ): %v", guidance, a.String()) return fmt.Errorf("generated character is not valid ( guidance: %s ): %v", guidance, a.Character.String())
} }
return nil return nil
} }
@@ -80,13 +112,13 @@ Hobbies: %v
Music taste: %v Music taste: %v
=====================` =====================`
func (a *Agent) String() string { func (c *Character) String() string {
return fmt.Sprintf( return fmt.Sprintf(
fmtT, fmtT,
a.Character.Name, c.Name,
a.Character.Age, c.Age,
a.Character.Occupation, c.Occupation,
a.Character.Hobbies, c.Hobbies,
a.Character.MusicTaste, c.MusicTaste,
) )
} }