fix state update, save/load
This commit is contained in:
@@ -24,7 +24,7 @@ func NewContext(ctx context.Context, cancel context.CancelFunc) *ActionContext {
|
||||
}
|
||||
}
|
||||
|
||||
type ActionParams map[string]string
|
||||
type ActionParams map[string]interface{}
|
||||
|
||||
func (ap ActionParams) Read(s string) error {
|
||||
err := json.Unmarshal([]byte(s), &ap)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package action
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
|
||||
@@ -31,7 +33,7 @@ type StateResult struct {
|
||||
}
|
||||
|
||||
func (a *StateAction) Run(ActionParams) (string, error) {
|
||||
return "no-op", nil
|
||||
return "internal state has been updated", nil
|
||||
}
|
||||
|
||||
func (a *StateAction) Definition() ActionDefinition {
|
||||
@@ -68,3 +70,23 @@ func (a *StateAction) Definition() ActionDefinition {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const fmtT = `=====================
|
||||
NowDoing: %s
|
||||
DoingNext: %s
|
||||
Your current goal is: %s
|
||||
You have done: %+v
|
||||
You have a short memory with: %+v
|
||||
=====================
|
||||
`
|
||||
|
||||
func (c StateResult) String() string {
|
||||
return fmt.Sprintf(
|
||||
fmtT,
|
||||
c.NowDoing,
|
||||
c.DoingNext,
|
||||
c.Goal,
|
||||
c.DoneHistory,
|
||||
c.Memories,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -117,13 +117,12 @@ const hudTemplate = `You have a character and your replies and actions might be
|
||||
{{end}}
|
||||
|
||||
This is your current state:
|
||||
{{if .CurrentState.NowDoing}}NowDoing: {{.CurrentState.NowDoing}} {{end}}
|
||||
{{if .CurrentState.DoingNext}}DoingNext: {{.CurrentState.DoingNext}} {{end}}
|
||||
{{if .PermanentGoal}}Your permanent goal is: {{.PermanentGoal}} {{end}}
|
||||
{{if .CurrentState.Goal}}Your current goal is: {{.CurrentState.Goal}} {{end}}
|
||||
NowDoing: {{if .CurrentState.NowDoing}}{{.CurrentState.NowDoing}}{{else}}Nothing{{end}}
|
||||
DoingNext: {{if .CurrentState.DoingNext}}{{.CurrentState.DoingNext}}{{else}}Nothing{{end}}
|
||||
Your permanent goal is: {{if .PermanentGoal}}{{.PermanentGoal}}{{else}}Nothing{{end}}
|
||||
Your current goal is: {{if .CurrentState.Goal}}{{.CurrentState.Goal}}{{else}}Nothing{{end}}
|
||||
You have done: {{range .CurrentState.DoneHistory}}{{.}} {{end}}
|
||||
You have a short memory with: {{range .CurrentState.Memories}}{{.}} {{end}}
|
||||
`
|
||||
You have a short memory with: {{range .CurrentState.Memories}}{{.}} {{end}}`
|
||||
|
||||
// pickAction picks an action based on the conversation
|
||||
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (Action, string, error) {
|
||||
@@ -160,9 +159,11 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
//fmt.Println("=== HUD START ===", hud.String(), "=== HUD END ===")
|
||||
|
||||
//fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===")
|
||||
if a.options.debugMode {
|
||||
fmt.Println("=== HUD START ===", hud.String(), "=== HUD END ===")
|
||||
fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===")
|
||||
}
|
||||
|
||||
// Get all the available actions IDs
|
||||
actionsID := []string{}
|
||||
@@ -214,7 +215,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
|
||||
Actions{intentionsTools}.ToTools(),
|
||||
intentionsTools.Definition().Name)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err)
|
||||
}
|
||||
|
||||
actionChoice := action.IntentResponse{}
|
||||
|
||||
@@ -3,6 +3,7 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -92,6 +93,35 @@ func New(opts ...Option) (*Agent, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if a.options.statefile != "" {
|
||||
if _, err := os.Stat(a.options.statefile); err == nil {
|
||||
if err = a.LoadState(a.options.statefile); err != nil {
|
||||
return a, fmt.Errorf("failed to load state: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if a.options.characterfile != "" {
|
||||
if _, err := os.Stat(a.options.characterfile); err == nil {
|
||||
// if there is a file, load the character back
|
||||
if err = a.LoadCharacter(a.options.characterfile); err != nil {
|
||||
return a, fmt.Errorf("failed to load character: %v", err)
|
||||
}
|
||||
} else {
|
||||
// otherwise save it for next time
|
||||
if err = a.SaveCharacter(a.options.characterfile); err != nil {
|
||||
return a, fmt.Errorf("failed to save character: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if a.options.debugMode {
|
||||
fmt.Println("=== Agent in Debug mode ===")
|
||||
fmt.Println(a.Character.String())
|
||||
fmt.Println(a.State().String())
|
||||
fmt.Println("Permanent goal: ", a.options.permanentGoal)
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
@@ -144,16 +174,28 @@ func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (
|
||||
}
|
||||
}
|
||||
|
||||
if a.options.debugMode {
|
||||
fmt.Println("Action", chosenAction.Definition().Name)
|
||||
fmt.Println("Result", result)
|
||||
}
|
||||
|
||||
if chosenAction.Definition().Name.Is(action.StateActionName) {
|
||||
// We need to store the result in the state
|
||||
state := action.StateResult{}
|
||||
|
||||
err = decisionResult.actionParams.Unmarshal(&result)
|
||||
err = decisionResult.actionParams.Unmarshal(&state)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("error unmarshalling state of the agent: %w", err)
|
||||
}
|
||||
// update the current state with the one we just got from the action
|
||||
a.currentState = &state
|
||||
|
||||
// update the state file
|
||||
if a.options.statefile != "" {
|
||||
if err := a.SaveState(a.options.statefile); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@@ -207,7 +249,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
|
||||
|
||||
params, err := a.generateParameters(ctx, chosenAction, a.currentConversation)
|
||||
if err != nil {
|
||||
job.Result.Finish(err)
|
||||
job.Result.Finish(fmt.Errorf("error generating action's parameters: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -222,6 +264,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
|
||||
return
|
||||
}
|
||||
|
||||
// If we don't have to reply , run the action!
|
||||
if !chosenAction.Definition().Name.Is(action.ReplyActionName) {
|
||||
result, err := a.runAction(chosenAction, params)
|
||||
if err != nil {
|
||||
|
||||
@@ -125,5 +125,27 @@ var _ = Describe("Agent test", func() {
|
||||
}
|
||||
Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
|
||||
})
|
||||
|
||||
It("updates the state with internal actions", func() {
|
||||
agent, err := New(
|
||||
WithLLMAPIURL(apiModel),
|
||||
WithModel(testModel),
|
||||
EnableHUD,
|
||||
DebugMode,
|
||||
// EnableStandaloneJob,
|
||||
WithRandomIdentity(),
|
||||
WithPermanentGoal("I want to learn to play music"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go agent.Run()
|
||||
defer agent.Stop()
|
||||
|
||||
result := agent.Ask(
|
||||
WithText("Update your goals such as you want to learn to play the guitar"),
|
||||
)
|
||||
fmt.Printf("%+v\n", result)
|
||||
Expect(result.Error).ToNot(HaveOccurred())
|
||||
Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -19,6 +19,9 @@ type options struct {
|
||||
randomIdentity bool
|
||||
userActions Actions
|
||||
enableHUD, standaloneJob bool
|
||||
debugMode bool
|
||||
characterfile string
|
||||
statefile string
|
||||
context context.Context
|
||||
permanentGoal string
|
||||
}
|
||||
@@ -54,6 +57,11 @@ var EnableHUD = func(o *options) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var DebugMode = func(o *options) error {
|
||||
o.debugMode = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableStandaloneJob is an option to enable the agent
|
||||
// to run jobs in the background automatically
|
||||
var EnableStandaloneJob = func(o *options) error {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/mudler/local-agent-framework/action"
|
||||
"github.com/mudler/local-agent-framework/llm"
|
||||
@@ -40,8 +41,39 @@ func Load(path string) (*Character, error) {
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
func (a *Agent) Save(path string) error {
|
||||
data, err := json.Marshal(a.options.character)
|
||||
func (a *Agent) State() action.StateResult {
|
||||
return *a.currentState
|
||||
}
|
||||
|
||||
func (a *Agent) LoadState(path string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, a.currentState)
|
||||
}
|
||||
|
||||
func (a *Agent) LoadCharacter(path string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, &a.Character)
|
||||
}
|
||||
|
||||
func (a *Agent) SaveState(path string) error {
|
||||
os.MkdirAll(filepath.Dir(path), 0755)
|
||||
data, err := json.Marshal(a.currentState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
os.WriteFile(path, data, 0644)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Agent) SaveCharacter(path string) error {
|
||||
os.MkdirAll(filepath.Dir(path), 0755)
|
||||
data, err := json.Marshal(a.Character)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -59,7 +91,7 @@ func (a *Agent) generateIdentity(guidance string) error {
|
||||
}
|
||||
|
||||
if !a.validCharacter() {
|
||||
return fmt.Errorf("generated character is not valid ( guidance: %s ): %v", guidance, a.String())
|
||||
return fmt.Errorf("generated character is not valid ( guidance: %s ): %v", guidance, a.Character.String())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -80,13 +112,13 @@ Hobbies: %v
|
||||
Music taste: %v
|
||||
=====================`
|
||||
|
||||
func (a *Agent) String() string {
|
||||
func (c *Character) String() string {
|
||||
return fmt.Sprintf(
|
||||
fmtT,
|
||||
a.Character.Name,
|
||||
a.Character.Age,
|
||||
a.Character.Occupation,
|
||||
a.Character.Hobbies,
|
||||
a.Character.MusicTaste,
|
||||
c.Name,
|
||||
c.Age,
|
||||
c.Occupation,
|
||||
c.Hobbies,
|
||||
c.MusicTaste,
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user