diff --git a/Makefile b/Makefile index 920cace..48c139c 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ GOCMD=go tests: - $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts 5 --fail-fast -v -r ./... \ No newline at end of file + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --fail-fast -v -r ./... \ No newline at end of file diff --git a/agent/actions.go b/agent/actions.go new file mode 100644 index 0000000..4883155 --- /dev/null +++ b/agent/actions.go @@ -0,0 +1 @@ +package agent diff --git a/agent/ask.go b/agent/ask.go new file mode 100644 index 0000000..4883155 --- /dev/null +++ b/agent/ask.go @@ -0,0 +1 @@ +package agent diff --git a/agent/constructor.go b/agent/constructor.go index 7eed122..4cd4061 100644 --- a/agent/constructor.go +++ b/agent/constructor.go @@ -1,18 +1,23 @@ package agent import ( + "fmt" + "github.com/mudler/local-agent-framework/llm" "github.com/sashabaranov/go-openai" ) type llmOptions struct { APIURL string + APIKey string Model string } type options struct { - LLMAPI llmOptions - Character Character + LLMAPI llmOptions + character Character + randomIdentityGuidance string + randomIdentity bool } type Agent struct { @@ -29,7 +34,7 @@ func defaultOptions() *options { APIURL: "http://localhost:8080", Model: "echidna", }, - Character: Character{ + character: Character{ Name: "John Doe", Age: 0, Occupation: "Unemployed", @@ -59,13 +64,18 @@ func New(opts ...Option) (*Agent, error) { return nil, err } - client := llm.NewClient("", options.LLMAPI.APIURL) + client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL) a := &Agent{ options: options, client: client, - Character: options.Character, + Character: options.character, } - return a, nil + + if a.options.randomIdentity { + err = a.generateIdentity("") + } + + return a, err } func WithLLMAPIURL(url string) Option { @@ -75,6 +85,13 @@ func WithLLMAPIURL(url string) Option { } } +func WithLLMAPIKey(key string) Option { + return func(o *options) error { + o.LLMAPI.APIKey = key + return nil + } +} + func WithModel(model string) Option { return func(o *options) error { o.LLMAPI.Model = model @@ -84,7 +101,7 @@ func WithModel(model string) Option { func WithCharacter(c Character) Option { return func(o *options) error { - o.Character = c + o.character = c return nil } } @@ -95,7 +112,15 @@ func FromFile(path string) Option { if err != nil { return err } - o.Character = *c + o.character = *c + return nil + } +} + +func WithRandomIdentity(guidance ...string) Option { + return func(o *options) error { + o.randomIdentityGuidance = fmt.Sprint(guidance) + o.randomIdentity = true return nil } } diff --git a/agent/state.go b/agent/state.go index 9a45669..27d8c15 100644 --- a/agent/state.go +++ b/agent/state.go @@ -2,6 +2,7 @@ package agent import ( "encoding/json" + "fmt" "os" "github.com/mudler/local-agent-framework/llm" @@ -17,6 +18,7 @@ type Character struct { Memories []string `json:"memories"` Hobbies []string `json:"hobbies"` MusicTaste []string `json:"music_taste"` + Sex string `json:"sex"` } func Load(path string) (*Character, error) { @@ -33,16 +35,37 @@ func Load(path string) (*Character, error) { } func (a *Agent) Save(path string) error { - data, err := json.Marshal(a.options.Character) + data, err := json.Marshal(a.options.character) if err != nil { return err } return os.WriteFile(path, data, 0644) } -func (a *Agent) GenerateIdentity(guidance string) error { - err := llm.GenerateJSONFromStruct(a.client, guidance, a.options.LLMAPI.Model, &a.options.Character) +func (a *Agent) generateIdentity(guidance string) error { + if guidance == "" { + guidance = "Generate a random character for roleplaying." + } + err := llm.GenerateJSONFromStruct(a.client, guidance, a.options.LLMAPI.Model, &a.options.character) + a.Character = a.options.character + if err != nil { + return err + } - a.Character = a.options.Character - return err + if !a.validCharacter() { + return fmt.Errorf("invalid character") + } + return nil +} + +func (a *Agent) validCharacter() bool { + return a.Character.Name != "" && + a.Character.Age != 0 && + a.Character.Occupation != "" && + a.Character.NowDoing != "" && + a.Character.DoingNext != "" && + len(a.Character.DoneHistory) != 0 && + len(a.Character.Memories) != 0 && + len(a.Character.Hobbies) != 0 && + len(a.Character.MusicTaste) != 0 } diff --git a/agent/state_test.go b/agent/state_test.go index 9700ee0..1df175c 100644 --- a/agent/state_test.go +++ b/agent/state_test.go @@ -11,13 +11,34 @@ import ( var _ = Describe("Agent test", func() { Context("identity", func() { - + It("generates all the fields with random data", func() { + agent, err := New( + WithLLMAPIURL("http://192.168.68.113:8080"), + WithModel("echidna"), + WithRandomIdentity(), + ) + Expect(err).ToNot(HaveOccurred()) + Expect(agent.Character.Name).ToNot(BeEmpty()) + Expect(agent.Character.Age).ToNot(BeZero()) + Expect(agent.Character.Occupation).ToNot(BeEmpty()) + Expect(agent.Character.NowDoing).ToNot(BeEmpty()) + Expect(agent.Character.DoingNext).ToNot(BeEmpty()) + Expect(agent.Character.DoneHistory).ToNot(BeEmpty()) + Expect(agent.Character.Memories).ToNot(BeEmpty()) + Expect(agent.Character.Hobbies).ToNot(BeEmpty()) + Expect(agent.Character.MusicTaste).ToNot(BeEmpty()) + fmt.Printf("%+v\n", agent.Character) + }) + It("detect an invalid character", func() { + _, err := New(WithRandomIdentity()) + Expect(err).To(HaveOccurred()) + }) It("generates all the fields", func() { agent, err := New( WithLLMAPIURL("http://192.168.68.113:8080"), - WithModel("echidna")) - Expect(err).ToNot(HaveOccurred()) - err = agent.GenerateIdentity("An old man with a long beard, a wizard, who lives in a tower.") + WithModel("echidna"), + WithRandomIdentity("An old man with a long beard, a wizard, who lives in a tower."), + ) Expect(err).ToNot(HaveOccurred()) Expect(agent.Character.Name).ToNot(BeEmpty()) Expect(agent.Character.Age).ToNot(BeZero())