diff --git a/action/definition.go b/action/definition.go index c4b1414..d4cb742 100644 --- a/action/definition.go +++ b/action/definition.go @@ -24,7 +24,7 @@ func NewContext(ctx context.Context, cancel context.CancelFunc) *ActionContext { } } -type ActionParams map[string]string +type ActionParams map[string]interface{} func (ap ActionParams) Read(s string) error { err := json.Unmarshal([]byte(s), &ap) diff --git a/action/state.go b/action/state.go index 649928f..9deb194 100644 --- a/action/state.go +++ b/action/state.go @@ -1,6 +1,8 @@ package action import ( + "fmt" + "github.com/sashabaranov/go-openai/jsonschema" ) @@ -31,7 +33,7 @@ type StateResult struct { } func (a *StateAction) Run(ActionParams) (string, error) { - return "no-op", nil + return "internal state has been updated", nil } func (a *StateAction) Definition() ActionDefinition { @@ -68,3 +70,23 @@ func (a *StateAction) Definition() ActionDefinition { }, } } + +const fmtT = `===================== +NowDoing: %s +DoingNext: %s +Your current goal is: %s +You have done: %+v +You have a short memory with: %+v +===================== +` + +func (c StateResult) String() string { + return fmt.Sprintf( + fmtT, + c.NowDoing, + c.DoingNext, + c.Goal, + c.DoneHistory, + c.Memories, + ) +} diff --git a/agent/actions.go b/agent/actions.go index a6bda21..65aa40e 100644 --- a/agent/actions.go +++ b/agent/actions.go @@ -117,13 +117,12 @@ const hudTemplate = `You have a character and your replies and actions might be {{end}} This is your current state: -{{if .CurrentState.NowDoing}}NowDoing: {{.CurrentState.NowDoing}} {{end}} -{{if .CurrentState.DoingNext}}DoingNext: {{.CurrentState.DoingNext}} {{end}} -{{if .PermanentGoal}}Your permanent goal is: {{.PermanentGoal}} {{end}} -{{if .CurrentState.Goal}}Your current goal is: {{.CurrentState.Goal}} {{end}} +NowDoing: {{if .CurrentState.NowDoing}}{{.CurrentState.NowDoing}}{{else}}Nothing{{end}} +DoingNext: {{if .CurrentState.DoingNext}}{{.CurrentState.DoingNext}}{{else}}Nothing{{end}} +Your permanent goal is: {{if .PermanentGoal}}{{.PermanentGoal}}{{else}}Nothing{{end}} +Your current goal is: {{if .CurrentState.Goal}}{{.CurrentState.Goal}}{{else}}Nothing{{end}} You have done: {{range .CurrentState.DoneHistory}}{{.}} {{end}} -You have a short memory with: {{range .CurrentState.Memories}}{{.}} {{end}} -` +You have a short memory with: {{range .CurrentState.Memories}}{{.}} {{end}}` // pickAction picks an action based on the conversation func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (Action, string, error) { @@ -160,9 +159,11 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai. if err != nil { return nil, "", err } - //fmt.Println("=== HUD START ===", hud.String(), "=== HUD END ===") - //fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===") + if a.options.debugMode { + fmt.Println("=== HUD START ===", hud.String(), "=== HUD END ===") + fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===") + } // Get all the available actions IDs actionsID := []string{} @@ -214,7 +215,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai. Actions{intentionsTools}.ToTools(), intentionsTools.Definition().Name) if err != nil { - return nil, "", err + return nil, "", fmt.Errorf("failed to get the action tool parameters: %v", err) } actionChoice := action.IntentResponse{} diff --git a/agent/agent.go b/agent/agent.go index e6b8d70..c551163 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -3,6 +3,7 @@ package agent import ( "context" "fmt" + "os" "strings" "sync" "time" @@ -92,6 +93,35 @@ func New(opts ...Option) (*Agent, error) { } } + if a.options.statefile != "" { + if _, err := os.Stat(a.options.statefile); err == nil { + if err = a.LoadState(a.options.statefile); err != nil { + return a, fmt.Errorf("failed to load state: %v", err) + } + } + } + + if a.options.characterfile != "" { + if _, err := os.Stat(a.options.characterfile); err == nil { + // if there is a file, load the character back + if err = a.LoadCharacter(a.options.characterfile); err != nil { + return a, fmt.Errorf("failed to load character: %v", err) + } + } else { + // otherwise save it for next time + if err = a.SaveCharacter(a.options.characterfile); err != nil { + return a, fmt.Errorf("failed to save character: %v", err) + } + } + } + + if a.options.debugMode { + fmt.Println("=== Agent in Debug mode ===") + fmt.Println(a.Character.String()) + fmt.Println(a.State().String()) + fmt.Println("Permanent goal: ", a.options.permanentGoal) + } + return a, nil } @@ -144,16 +174,28 @@ func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) ( } } + if a.options.debugMode { + fmt.Println("Action", chosenAction.Definition().Name) + fmt.Println("Result", result) + } + if chosenAction.Definition().Name.Is(action.StateActionName) { // We need to store the result in the state state := action.StateResult{} - err = decisionResult.actionParams.Unmarshal(&result) + err = decisionResult.actionParams.Unmarshal(&state) if err != nil { - return "", err + return "", fmt.Errorf("error unmarshalling state of the agent: %w", err) } // update the current state with the one we just got from the action a.currentState = &state + + // update the state file + if a.options.statefile != "" { + if err := a.SaveState(a.options.statefile); err != nil { + return "", err + } + } } return result, nil @@ -207,7 +249,7 @@ func (a *Agent) consumeJob(job *Job, role string) { params, err := a.generateParameters(ctx, chosenAction, a.currentConversation) if err != nil { - job.Result.Finish(err) + job.Result.Finish(fmt.Errorf("error generating action's parameters: %w", err)) return } @@ -222,6 +264,7 @@ func (a *Agent) consumeJob(job *Job, role string) { return } + // If we don't have to reply , run the action! if !chosenAction.Definition().Name.Is(action.ReplyActionName) { result, err := a.runAction(chosenAction, params) if err != nil { diff --git a/agent/agent_test.go b/agent/agent_test.go index 3861ce3..b532109 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -125,5 +125,27 @@ var _ = Describe("Agent test", func() { } Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res)) }) + + It("updates the state with internal actions", func() { + agent, err := New( + WithLLMAPIURL(apiModel), + WithModel(testModel), + EnableHUD, + DebugMode, + // EnableStandaloneJob, + WithRandomIdentity(), + WithPermanentGoal("I want to learn to play music"), + ) + Expect(err).ToNot(HaveOccurred()) + go agent.Run() + defer agent.Stop() + + result := agent.Ask( + WithText("Update your goals such as you want to learn to play the guitar"), + ) + fmt.Printf("%+v\n", result) + Expect(result.Error).ToNot(HaveOccurred()) + Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State())) + }) }) }) diff --git a/agent/options.go b/agent/options.go index baef145..c9c6e42 100644 --- a/agent/options.go +++ b/agent/options.go @@ -19,6 +19,9 @@ type options struct { randomIdentity bool userActions Actions enableHUD, standaloneJob bool + debugMode bool + characterfile string + statefile string context context.Context permanentGoal string } @@ -54,6 +57,11 @@ var EnableHUD = func(o *options) error { return nil } +var DebugMode = func(o *options) error { + o.debugMode = true + return nil +} + // EnableStandaloneJob is an option to enable the agent // to run jobs in the background automatically var EnableStandaloneJob = func(o *options) error { diff --git a/agent/state.go b/agent/state.go index e3f740d..c165f92 100644 --- a/agent/state.go +++ b/agent/state.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "os" + "path/filepath" "github.com/mudler/local-agent-framework/action" "github.com/mudler/local-agent-framework/llm" @@ -40,8 +41,39 @@ func Load(path string) (*Character, error) { return &c, nil } -func (a *Agent) Save(path string) error { - data, err := json.Marshal(a.options.character) +func (a *Agent) State() action.StateResult { + return *a.currentState +} + +func (a *Agent) LoadState(path string) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + return json.Unmarshal(data, a.currentState) +} + +func (a *Agent) LoadCharacter(path string) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + return json.Unmarshal(data, &a.Character) +} + +func (a *Agent) SaveState(path string) error { + os.MkdirAll(filepath.Dir(path), 0755) + data, err := json.Marshal(a.currentState) + if err != nil { + return err + } + os.WriteFile(path, data, 0644) + return nil +} + +func (a *Agent) SaveCharacter(path string) error { + os.MkdirAll(filepath.Dir(path), 0755) + data, err := json.Marshal(a.Character) if err != nil { return err } @@ -59,7 +91,7 @@ func (a *Agent) generateIdentity(guidance string) error { } if !a.validCharacter() { - return fmt.Errorf("generated character is not valid ( guidance: %s ): %v", guidance, a.String()) + return fmt.Errorf("generated character is not valid ( guidance: %s ): %v", guidance, a.Character.String()) } return nil } @@ -80,13 +112,13 @@ Hobbies: %v Music taste: %v =====================` -func (a *Agent) String() string { +func (c *Character) String() string { return fmt.Sprintf( fmtT, - a.Character.Name, - a.Character.Age, - a.Character.Occupation, - a.Character.Hobbies, - a.Character.MusicTaste, + c.Name, + c.Age, + c.Occupation, + c.Hobbies, + c.MusicTaste, ) }