diff --git a/agent/actions.go b/agent/actions.go index 79d94b6..b035ba8 100644 --- a/agent/actions.go +++ b/agent/actions.go @@ -2,9 +2,12 @@ package agent import ( "context" + "encoding/json" "fmt" + "time" "github.com/mudler/local-agent-framework/llm" + "github.com/sashabaranov/go-openai" ) type ActionContext struct { @@ -12,16 +15,31 @@ type ActionContext struct { cancelFunc context.CancelFunc } +type ActionParams map[string]string + +func (ap ActionParams) Read(s string) error { + err := json.Unmarshal([]byte(s), &ap) + return err +} + +type ActionDefinition openai.FunctionDefinition + +func (a ActionDefinition) FD() openai.FunctionDefinition { + return openai.FunctionDefinition(a) +} + // Actions is something the agent can do type Action interface { - Description() string ID() string - Run(map[string]string) error + Run(ActionParams) (string, error) + Definition() ActionDefinition } var ErrContextCanceled = fmt.Errorf("context canceled") func (a *Agent) Stop() { + a.Lock() + defer a.Unlock() a.context.cancelFunc() } @@ -38,14 +56,31 @@ func (a *Agent) Run() error { // Expose a REST API to interact with the agent to ask it things + fmt.Println("Agent is running") + clearConvTimer := time.NewTicker(1 * time.Minute) for { + fmt.Println("Agent loop") + select { case job := <-a.jobQueue: + fmt.Println("job from the queue") + // Consume the job and generate a response + // TODO: Give a short-term memory to the agent a.consumeJob(job) case <-a.context.Done(): + fmt.Println("Context canceled, agent is stopping...") + // Agent has been canceled, return error return ErrContextCanceled + case <-clearConvTimer.C: + fmt.Println("Removing chat history...") + + // TODO: decide to do something on its own with the conversation result + // before clearing it out + + // Clear the conversation + a.currentConversation = []openai.ChatCompletionMessage{} } } } @@ -53,21 +88,24 @@ func (a *Agent) Run() error { // StopAction stops the current action // if any. 
Can be called before adding a new job. func (a *Agent) StopAction() { + a.Lock() + defer a.Unlock() if a.actionContext != nil { a.actionContext.cancelFunc() } } func (a *Agent) consumeJob(job *Job) { - // Consume the job and generate a response - // Implement your logic here + // Consume the job and generate a response + a.Lock() // Set the action context ctx, cancel := context.WithCancel(context.Background()) a.actionContext = &ActionContext{ Context: ctx, cancelFunc: cancel, } + a.Unlock() if job.Image != "" { // TODO: Use llava to explain the image content @@ -78,20 +116,115 @@ func (a *Agent) consumeJob(job *Job) { return } - decision := struct { - Action string `json:"action"` - }{ - Action: "generate_identity", + actionChoice := struct { + Choice string `json:"choice"` + }{} + + llm.GenerateJSON(ctx, a.client, a.options.LLMAPI.Model, "decide which action to take", &actionChoice) + + // https://github.com/sashabaranov/go-openai/blob/0925563e86c2fdc5011310aa616ba493989cfe0a/examples/completion-with-tool/main.go#L16 + actions := a.options.actions + tools := []openai.Tool{} + + messages := a.currentConversation + if job.Text != "" { + messages = append(messages, openai.ChatCompletionMessage{ + Role: "user", + Content: job.Text, + }) } - llm.GenerateJSON(ctx, a.client, a.options.LLMAPI.Model, - "decide which action to take give the", - &decision) + for _, action := range actions { + tools = append(tools, openai.Tool{ + Type: openai.ToolTypeFunction, + Function: action.Definition().FD(), + }) + } + + decision := openai.ChatCompletionRequest{ + Model: a.options.LLMAPI.Model, + Messages: messages, + Tools: tools, + } + resp, err := a.client.CreateChatCompletion(ctx, decision) + if err != nil || len(resp.Choices) != 1 { + fmt.Printf("Completion error: err:%v len(choices):%v\n", err, + len(resp.Choices)) + return + } + + msg := resp.Choices[0].Message + if len(msg.ToolCalls) != 1 { + fmt.Printf("Completion error: len(toolcalls): %v\n", len(msg.ToolCalls)) + return + } + + // simulate calling the
function & responding to OpenAI + messages = append(messages, msg) + fmt.Printf("OpenAI called us back wanting to invoke our function '%v' with params '%v'\n", + msg.ToolCalls[0].Function.Name, msg.ToolCalls[0].Function.Arguments) + + params := ActionParams{} + if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil { + fmt.Printf("error unmarshalling arguments: %v\n", err) + return + } + + var result string + for _, action := range actions { + fmt.Println("Checking action: ", action.ID()) + fmt.Println("Checking action: ", msg.ToolCalls[0].Function.Name) + if action.ID() == msg.ToolCalls[0].Function.Name { + fmt.Printf("Running action: %v\n", action.ID()) + if result, err = action.Run(params); err != nil { + fmt.Printf("error running action: %v\n", err) + return + } + } + } + fmt.Printf("Action run result: %v\n", result) + + // simulate calling the function + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleTool, + Content: result, + Name: msg.ToolCalls[0].Function.Name, + ToolCallID: msg.ToolCalls[0].ID, + }) + + resp, err = a.client.CreateChatCompletion(ctx, + openai.ChatCompletionRequest{ + Model: a.options.LLMAPI.Model, + Messages: messages, + Tools: tools, + }, + ) + if err != nil || len(resp.Choices) != 1 { + fmt.Printf("2nd completion error: err:%v len(choices):%v\n", err, + len(resp.Choices)) + return + } + + // display OpenAI's response to the original question utilizing our function + msg = resp.Choices[0].Message + fmt.Printf("OpenAI answered the original request with: %v\n", + msg.Content) + + messages = append(messages, msg) + a.currentConversation = append(a.currentConversation, messages...) + + if len(msg.ToolCalls) != 0 { + fmt.Printf("OpenAI wants to call again functions: %v\n", msg) + // wants to call again an action (?) 
+ job.Text = "" // Call the job with the current conversation + job.Result.SetResult(result) + a.jobQueue <- job + return + } // perform the action (if any) // or reply with a result - // if there is an action... - job.Result.SetResult("I don't know how to do that yet.") - + job.Result.SetResult(result) + job.Result.Finish() } diff --git a/agent/agent.go b/agent/agent.go index d19b521..bf032d2 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -3,12 +3,14 @@ package agent import ( "context" "fmt" + "sync" "github.com/mudler/local-agent-framework/llm" "github.com/sashabaranov/go-openai" ) type Agent struct { + sync.Mutex options *options Character Character client *openai.Client @@ -16,6 +18,8 @@ type Agent struct { actionContext *ActionContext context *ActionContext availableActions []Action + + currentConversation []openai.ChatCompletionMessage } func New(opts ...Option) (*Agent, error) { @@ -36,6 +40,7 @@ func New(opts ...Option) (*Agent, error) { ctx, cancel := context.WithCancel(c) a := &Agent{ + jobQueue: make(chan *Job), options: options, client: client, Character: options.character, @@ -43,6 +48,7 @@ func New(opts ...Option) (*Agent, error) { Context: ctx, cancelFunc: cancel, }, + availableActions: options.actions, } if a.options.randomIdentity { @@ -56,9 +62,12 @@ func New(opts ...Option) (*Agent, error) { // Ask is a pre-emptive, blocking call that returns the response as soon as it's ready. // It discards any other computation. -func (a *Agent) Ask(text, image string) string { - a.StopAction() +func (a *Agent) Ask(text, image string) []string { + //a.StopAction() j := NewJob(text, image) + fmt.Println("Job created", text) a.jobQueue <- j + fmt.Println("Waiting for result") + return j.Result.WaitResult() } diff --git a/agent/agent_suite_test.go b/agent/agent_suite_test.go index ff10f4c..7db5e21 100644 --- a/agent/agent_suite_test.go +++ b/agent/agent_suite_test.go @@ -1,6 +1,7 @@ package agent_test import ( + "os" "testing" . 
"github.com/onsi/ginkgo/v2" @@ -11,3 +12,15 @@ func TestAgent(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Agent test suite") } + +var testModel = os.Getenv("TEST_MODEL") +var apiModel = os.Getenv("API_MODEL") + +func init() { + if testModel == "" { + testModel = "hermes-2-pro-mistral" + } + if apiModel == "" { + apiModel = "http://192.168.68.113:8080" + } +} diff --git a/agent/jobs.go b/agent/jobs.go index 4cc3b06..f7ccdeb 100644 --- a/agent/jobs.go +++ b/agent/jobs.go @@ -16,7 +16,7 @@ type Job struct { type JobResult struct { sync.Mutex // The result of a job - Text string + Data []string ready chan bool } @@ -44,14 +44,21 @@ func (j *JobResult) SetResult(text string) { j.Lock() defer j.Unlock() - j.Text = text + j.Data = append(j.Data, text) +} + +// SetResult sets the result of a job +func (j *JobResult) Finish() { + j.Lock() + defer j.Unlock() + close(j.ready) } // WaitResult waits for the result of a job -func (j *JobResult) WaitResult() string { +func (j *JobResult) WaitResult() []string { <-j.ready j.Lock() defer j.Unlock() - return j.Text + return j.Data } diff --git a/agent/options.go b/agent/options.go index 37bac26..851537c 100644 --- a/agent/options.go +++ b/agent/options.go @@ -17,6 +17,7 @@ type options struct { character Character randomIdentityGuidance string randomIdentity bool + actions []Action context context.Context } @@ -103,3 +104,10 @@ func WithRandomIdentity(guidance ...string) Option { return nil } } + +func WithActions(actions ...Action) Option { + return func(o *options) error { + o.actions = actions + return nil + } +} diff --git a/agent/state_test.go b/agent/state_test.go index f8fa0e7..d5d40e9 100644 --- a/agent/state_test.go +++ b/agent/state_test.go @@ -13,8 +13,8 @@ var _ = Describe("Agent test", func() { Context("identity", func() { It("generates all the fields with random data", func() { agent, err := New( - WithLLMAPIURL("http://192.168.68.113:8080"), - WithModel("echidna"), + WithLLMAPIURL(apiModel), + 
WithModel(testModel), WithRandomIdentity(), ) Expect(err).ToNot(HaveOccurred()) @@ -35,8 +35,8 @@ var _ = Describe("Agent test", func() { }) It("generates all the fields", func() { agent, err := New( - WithLLMAPIURL("http://192.168.68.113:8080"), - WithModel("echidna"), + WithLLMAPIURL(apiModel), + WithModel(testModel), WithRandomIdentity("An old man with a long beard, a wizard, who lives in a tower."), ) Expect(err).ToNot(HaveOccurred())