fix state update, save/load
This commit is contained in:
@@ -24,7 +24,7 @@ func NewContext(ctx context.Context, cancel context.CancelFunc) *ActionContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ActionParams map[string]string
|
type ActionParams map[string]interface{}
|
||||||
|
|
||||||
func (ap ActionParams) Read(s string) error {
|
func (ap ActionParams) Read(s string) error {
|
||||||
err := json.Unmarshal([]byte(s), &ap)
|
err := json.Unmarshal([]byte(s), &ap)
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package action
|
package action
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,7 +33,7 @@ type StateResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *StateAction) Run(ActionParams) (string, error) {
|
func (a *StateAction) Run(ActionParams) (string, error) {
|
||||||
return "no-op", nil
|
return "internal state has been updated", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *StateAction) Definition() ActionDefinition {
|
func (a *StateAction) Definition() ActionDefinition {
|
||||||
@@ -68,3 +70,23 @@ func (a *StateAction) Definition() ActionDefinition {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const fmtT = `=====================
|
||||||
|
NowDoing: %s
|
||||||
|
DoingNext: %s
|
||||||
|
Your current goal is: %s
|
||||||
|
You have done: %+v
|
||||||
|
You have a short memory with: %+v
|
||||||
|
=====================
|
||||||
|
`
|
||||||
|
|
||||||
|
func (c StateResult) String() string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
fmtT,
|
||||||
|
c.NowDoing,
|
||||||
|
c.DoingNext,
|
||||||
|
c.Goal,
|
||||||
|
c.DoneHistory,
|
||||||
|
c.Memories,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@@ -117,13 +117,12 @@ const hudTemplate = `You have a character and your replies and actions might be
|
|||||||
{{end}}
|
{{end}}
|
||||||
|
|
||||||
This is your current state:
|
This is your current state:
|
||||||
{{if .CurrentState.NowDoing}}NowDoing: {{.CurrentState.NowDoing}} {{end}}
|
NowDoing: {{if .CurrentState.NowDoing}}{{.CurrentState.NowDoing}}{{else}}Nothing{{end}}
|
||||||
{{if .CurrentState.DoingNext}}DoingNext: {{.CurrentState.DoingNext}} {{end}}
|
DoingNext: {{if .CurrentState.DoingNext}}{{.CurrentState.DoingNext}}{{else}}Nothing{{end}}
|
||||||
{{if .PermanentGoal}}Your permanent goal is: {{.PermanentGoal}} {{end}}
|
Your permanent goal is: {{if .PermanentGoal}}{{.PermanentGoal}}{{else}}Nothing{{end}}
|
||||||
{{if .CurrentState.Goal}}Your current goal is: {{.CurrentState.Goal}} {{end}}
|
Your current goal is: {{if .CurrentState.Goal}}{{.CurrentState.Goal}}{{else}}Nothing{{end}}
|
||||||
You have done: {{range .CurrentState.DoneHistory}}{{.}} {{end}}
|
You have done: {{range .CurrentState.DoneHistory}}{{.}} {{end}}
|
||||||
You have a short memory with: {{range .CurrentState.Memories}}{{.}} {{end}}
|
You have a short memory with: {{range .CurrentState.Memories}}{{.}} {{end}}`
|
||||||
`
|
|
||||||
|
|
||||||
// pickAction picks an action based on the conversation
|
// pickAction picks an action based on the conversation
|
||||||
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (Action, string, error) {
|
func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (Action, string, error) {
|
||||||
@@ -160,9 +159,11 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
//fmt.Println("=== HUD START ===", hud.String(), "=== HUD END ===")
|
|
||||||
|
|
||||||
//fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===")
|
if a.options.debugMode {
|
||||||
|
fmt.Println("=== HUD START ===", hud.String(), "=== HUD END ===")
|
||||||
|
fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===")
|
||||||
|
}
|
||||||
|
|
||||||
// Get all the available actions IDs
|
// Get all the available actions IDs
|
||||||
actionsID := []string{}
|
actionsID := []string{}
|
||||||
@@ -214,7 +215,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
|
|||||||
Actions{intentionsTools}.ToTools(),
|
Actions{intentionsTools}.ToTools(),
|
||||||
intentionsTools.Definition().Name)
|
intentionsTools.Definition().Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
actionChoice := action.IntentResponse{}
|
actionChoice := action.IntentResponse{}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package agent
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -92,6 +93,35 @@ func New(opts ...Option) (*Agent, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if a.options.statefile != "" {
|
||||||
|
if _, err := os.Stat(a.options.statefile); err == nil {
|
||||||
|
if err = a.LoadState(a.options.statefile); err != nil {
|
||||||
|
return a, fmt.Errorf("failed to load state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.options.characterfile != "" {
|
||||||
|
if _, err := os.Stat(a.options.characterfile); err == nil {
|
||||||
|
// if there is a file, load the character back
|
||||||
|
if err = a.LoadCharacter(a.options.characterfile); err != nil {
|
||||||
|
return a, fmt.Errorf("failed to load character: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// otherwise save it for next time
|
||||||
|
if err = a.SaveCharacter(a.options.characterfile); err != nil {
|
||||||
|
return a, fmt.Errorf("failed to save character: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.options.debugMode {
|
||||||
|
fmt.Println("=== Agent in Debug mode ===")
|
||||||
|
fmt.Println(a.Character.String())
|
||||||
|
fmt.Println(a.State().String())
|
||||||
|
fmt.Println("Permanent goal: ", a.options.permanentGoal)
|
||||||
|
}
|
||||||
|
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,16 +174,28 @@ func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if a.options.debugMode {
|
||||||
|
fmt.Println("Action", chosenAction.Definition().Name)
|
||||||
|
fmt.Println("Result", result)
|
||||||
|
}
|
||||||
|
|
||||||
if chosenAction.Definition().Name.Is(action.StateActionName) {
|
if chosenAction.Definition().Name.Is(action.StateActionName) {
|
||||||
// We need to store the result in the state
|
// We need to store the result in the state
|
||||||
state := action.StateResult{}
|
state := action.StateResult{}
|
||||||
|
|
||||||
err = decisionResult.actionParams.Unmarshal(&result)
|
err = decisionResult.actionParams.Unmarshal(&state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", fmt.Errorf("error unmarshalling state of the agent: %w", err)
|
||||||
}
|
}
|
||||||
// update the current state with the one we just got from the action
|
// update the current state with the one we just got from the action
|
||||||
a.currentState = &state
|
a.currentState = &state
|
||||||
|
|
||||||
|
// update the state file
|
||||||
|
if a.options.statefile != "" {
|
||||||
|
if err := a.SaveState(a.options.statefile); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
@@ -207,7 +249,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
|
|||||||
|
|
||||||
params, err := a.generateParameters(ctx, chosenAction, a.currentConversation)
|
params, err := a.generateParameters(ctx, chosenAction, a.currentConversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
job.Result.Finish(err)
|
job.Result.Finish(fmt.Errorf("error generating action's parameters: %w", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,6 +264,7 @@ func (a *Agent) consumeJob(job *Job, role string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we don't have to reply , run the action!
|
||||||
if !chosenAction.Definition().Name.Is(action.ReplyActionName) {
|
if !chosenAction.Definition().Name.Is(action.ReplyActionName) {
|
||||||
result, err := a.runAction(chosenAction, params)
|
result, err := a.runAction(chosenAction, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -125,5 +125,27 @@ var _ = Describe("Agent test", func() {
|
|||||||
}
|
}
|
||||||
Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
|
Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("updates the state with internal actions", func() {
|
||||||
|
agent, err := New(
|
||||||
|
WithLLMAPIURL(apiModel),
|
||||||
|
WithModel(testModel),
|
||||||
|
EnableHUD,
|
||||||
|
DebugMode,
|
||||||
|
// EnableStandaloneJob,
|
||||||
|
WithRandomIdentity(),
|
||||||
|
WithPermanentGoal("I want to learn to play music"),
|
||||||
|
)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
go agent.Run()
|
||||||
|
defer agent.Stop()
|
||||||
|
|
||||||
|
result := agent.Ask(
|
||||||
|
WithText("Update your goals such as you want to learn to play the guitar"),
|
||||||
|
)
|
||||||
|
fmt.Printf("%+v\n", result)
|
||||||
|
Expect(result.Error).ToNot(HaveOccurred())
|
||||||
|
Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State()))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ type options struct {
|
|||||||
randomIdentity bool
|
randomIdentity bool
|
||||||
userActions Actions
|
userActions Actions
|
||||||
enableHUD, standaloneJob bool
|
enableHUD, standaloneJob bool
|
||||||
|
debugMode bool
|
||||||
|
characterfile string
|
||||||
|
statefile string
|
||||||
context context.Context
|
context context.Context
|
||||||
permanentGoal string
|
permanentGoal string
|
||||||
}
|
}
|
||||||
@@ -54,6 +57,11 @@ var EnableHUD = func(o *options) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var DebugMode = func(o *options) error {
|
||||||
|
o.debugMode = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// EnableStandaloneJob is an option to enable the agent
|
// EnableStandaloneJob is an option to enable the agent
|
||||||
// to run jobs in the background automatically
|
// to run jobs in the background automatically
|
||||||
var EnableStandaloneJob = func(o *options) error {
|
var EnableStandaloneJob = func(o *options) error {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/mudler/local-agent-framework/action"
|
"github.com/mudler/local-agent-framework/action"
|
||||||
"github.com/mudler/local-agent-framework/llm"
|
"github.com/mudler/local-agent-framework/llm"
|
||||||
@@ -40,8 +41,39 @@ func Load(path string) (*Character, error) {
|
|||||||
return &c, nil
|
return &c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) Save(path string) error {
|
func (a *Agent) State() action.StateResult {
|
||||||
data, err := json.Marshal(a.options.character)
|
return *a.currentState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) LoadState(path string) error {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return json.Unmarshal(data, a.currentState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) LoadCharacter(path string) error {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return json.Unmarshal(data, &a.Character)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) SaveState(path string) error {
|
||||||
|
os.MkdirAll(filepath.Dir(path), 0755)
|
||||||
|
data, err := json.Marshal(a.currentState)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
os.WriteFile(path, data, 0644)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Agent) SaveCharacter(path string) error {
|
||||||
|
os.MkdirAll(filepath.Dir(path), 0755)
|
||||||
|
data, err := json.Marshal(a.Character)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -59,7 +91,7 @@ func (a *Agent) generateIdentity(guidance string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !a.validCharacter() {
|
if !a.validCharacter() {
|
||||||
return fmt.Errorf("generated character is not valid ( guidance: %s ): %v", guidance, a.String())
|
return fmt.Errorf("generated character is not valid ( guidance: %s ): %v", guidance, a.Character.String())
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -80,13 +112,13 @@ Hobbies: %v
|
|||||||
Music taste: %v
|
Music taste: %v
|
||||||
=====================`
|
=====================`
|
||||||
|
|
||||||
func (a *Agent) String() string {
|
func (c *Character) String() string {
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
fmtT,
|
fmtT,
|
||||||
a.Character.Name,
|
c.Name,
|
||||||
a.Character.Age,
|
c.Age,
|
||||||
a.Character.Occupation,
|
c.Occupation,
|
||||||
a.Character.Hobbies,
|
c.Hobbies,
|
||||||
a.Character.MusicTaste,
|
c.MusicTaste,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user