diff --git a/agent/actions.go b/agent/actions.go index cf187e9..98f9c5d 100644 --- a/agent/actions.go +++ b/agent/actions.go @@ -1,13 +1,15 @@ package agent import ( + "bytes" "context" "encoding/json" "fmt" + "html/template" "time" //"github.com/mudler/local-agent-framework/llm" - "github.com/mudler/local-agent-framework/llm" + "github.com/sashabaranov/go-openai" ) @@ -23,6 +25,11 @@ func (ap ActionParams) Read(s string) error { return err } +func (ap ActionParams) String() string { + b, _ := json.Marshal(ap) + return string(b) +} + type ActionDefinition openai.FunctionDefinition func (a ActionDefinition) FD() openai.FunctionDefinition { @@ -97,6 +104,137 @@ func (a *Agent) StopAction() { } } +func (a *Agent) decision(ctx context.Context, conversation []openai.ChatCompletionMessage, tools []openai.Tool, toolchoice any) (ActionParams, error) { + decision := openai.ChatCompletionRequest{ + Model: a.options.LLMAPI.Model, + Messages: conversation, + Tools: tools, + ToolChoice: toolchoice, + } + resp, err := a.client.CreateChatCompletion(ctx, decision) + if err != nil || len(resp.Choices) != 1 { + fmt.Println("no choices", err) + + return nil, err + } + + msg := resp.Choices[0].Message + if len(msg.ToolCalls) != 1 { + return nil, fmt.Errorf("len(toolcalls): %v", len(msg.ToolCalls)) + } + + params := ActionParams{} + if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil { + fmt.Println("can't read params", err) + + return nil, err + } + + return params, nil +} + +type Actions []Action + +func (a Actions) ToTools() []openai.Tool { + tools := []openai.Tool{} + for _, action := range a { + tools = append(tools, openai.Tool{ + Type: openai.ToolTypeFunction, + Function: action.Definition().FD(), + }) + } + return tools +} + +func (a *Agent) generateParameters(ctx context.Context, action Action, conversation []openai.ChatCompletionMessage) (ActionParams, error) { + return a.decision(ctx, conversation, a.options.actions.ToTools(), action.ID()) +} + +const pickActionTemplate = `You can take any of the following tools: + +{{range .Actions}}{{.ID}}: {{.Description}}{{end}} + +or none. Given the text below, decide which action to take and explain the reasoning behind it. For answering without picking a choice, reply with 'none'. + +{{range .Messages}}{{.Content}}{{end}} +` + +func (a *Agent) pickAction(ctx context.Context, messages []openai.ChatCompletionMessage) (Action, error) { + actionChoice := struct { + Intent string `json:"tool"` + Reasoning string `json:"reasoning"` + }{} + + prompt := bytes.NewBuffer([]byte{}) + tmpl, err := template.New("pickAction").Parse(pickActionTemplate) + if err != nil { + return nil, err + } + + err = tmpl.Execute(prompt, struct { + Actions []Action + Messages []openai.ChatCompletionMessage + }{ + Actions: a.options.actions, + Messages: messages, + }) + if err != nil { + return nil, err + } + + fmt.Println(prompt.String()) + + actionsID := []string{} + for _, m := range a.options.actions { + actionsID = append(actionsID, m.ID()) + } + intentionsTools := NewIntention(actionsID...) + + conversation := []openai.ChatCompletionMessage{ + { + Role: "user", + Content: prompt.String(), + }, + } + + params, err := a.decision(ctx, conversation, Actions{intentionsTools}.ToTools(), intentionsTools.ID()) + if err != nil { + fmt.Println("failed decision", err) + return nil, err + } + + dat, err := json.Marshal(params) + if err != nil { + return nil, err + } + + err = json.Unmarshal(dat, &actionChoice) + if err != nil { + return nil, err + } + + fmt.Printf("Action choice: %v\n", actionChoice) + if actionChoice.Intent == "" || actionChoice.Intent == "none" { + return nil, fmt.Errorf("no intent detected") + } + + // Find the action + var action Action + for _, a := range a.options.actions { + if a.ID() == actionChoice.Intent { + action = a + break + } + } + + if action == nil { + fmt.Println("No action found for intent: ", actionChoice.Intent) + return nil, fmt.Errorf("No action found for intent:" + actionChoice.Intent) + } + + return action, nil +} + func (a *Agent) consumeJob(job *Job) { // Consume the job and generate a response @@ -118,52 +256,6 @@ func (a *Agent) consumeJob(job *Job) { return } - actionChoice := struct { - Intent string `json:"intent"` - Reasoning string `json:"reasoning"` - }{} - - action_pick := "You can take any action between: " - for _, action := range a.options.actions { - action_pick += action.ID() + ": " + action.Description() + ", " - } - - action_pick += "or none." - action_pick += "Given the text below, decide which action to take and explain the reasoning behind it. For answering without picking a choice, reply with 'none'." - action_pick += "return the result as a JSON object with the 'intent' and 'reasoning' fields." - - err := llm.GenerateJSON(ctx, a.client, a.options.LLMAPI.Model, action_pick, &actionChoice) - if err != nil { - fmt.Println("Error generating JSON: ", err) - return - } - - fmt.Println("Action choice: ", actionChoice) - if actionChoice.Intent == "" || actionChoice.Intent == "none" { - fmt.Println("No intent detected") - return - } - - // Find the action - var action Action - for _, a := range a.options.actions { - if a.ID() == actionChoice.Intent { - action = a - break - } - } - - if action == nil { - fmt.Println("No action found for intent: ", actionChoice.Intent) - return - } - - // Fill the action parameters - - // https://github.com/sashabaranov/go-openai/blob/0925563e86c2fdc5011310aa616ba493989cfe0a/examples/completion-with-tool/main.go#L16 - actions := a.options.actions - tools := []openai.Tool{} - messages := a.currentConversation if job.Text != "" { messages = append(messages, openai.ChatCompletionMessage{ @@ -172,51 +264,21 @@ func (a *Agent) consumeJob(job *Job) { }) } - for _, action := range actions { - tools = append(tools, openai.Tool{ - Type: openai.ToolTypeFunction, - Function: action.Definition().FD(), - }) - } - - decision := openai.ChatCompletionRequest{ - Model: a.options.LLMAPI.Model, - Messages: messages, - Tools: tools, - ToolChoice: &openai.ToolChoice{ - Type: openai.ToolTypeFunction, - Function: openai.ToolFunction{Name: action.ID()}, - }, - } - resp, err := a.client.CreateChatCompletion(ctx, decision) - if err != nil || len(resp.Choices) != 1 { - fmt.Printf("Completion error: err:%v len(choices):%v\n", err, - len(resp.Choices)) + chosenAction, err := a.pickAction(ctx, messages) + if err != nil { + fmt.Printf("error picking action: %v\n", err) return } - - msg := resp.Choices[0].Message - if len(msg.ToolCalls) != 1 { - fmt.Printf("Completion error: len(toolcalls): %v\n", len(msg.ToolCalls)) - return - } - - // simulate calling the function & responding to OpenAI - messages = append(messages, msg) - fmt.Printf("OpenAI called us back wanting to invoke our function '%v' with params '%v'\n", - msg.ToolCalls[0].Function.Name, msg.ToolCalls[0].Function.Arguments) - - params := ActionParams{} - if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil { - fmt.Printf("error unmarshalling arguments: %v\n", err) + params, err := a.generateParameters(ctx, chosenAction, messages) + if err != nil { + fmt.Printf("error generating parameters: %v\n", err) return } var result string - for _, action := range actions { - fmt.Println("Checking action: ", action.ID()) - fmt.Println("Checking action: ", msg.ToolCalls[0].Function.Name) - if action.ID() == msg.ToolCalls[0].Function.Name { + for _, action := range a.options.actions { + fmt.Println("Checking action: ", action.ID(), chosenAction.ID()) + if action.ID() == chosenAction.ID() { fmt.Printf("Running action: %v\n", action.ID()) if result, err = action.Run(params); err != nil { fmt.Printf("error running action: %v\n", err) @@ -226,15 +288,24 @@ func (a *Agent) consumeJob(job *Job) { } fmt.Printf("Action run result: %v\n", result) + // calling the function + messages = append(messages, openai.ChatCompletionMessage{ + Role: "assistant", + FunctionCall: &openai.FunctionCall{ + Name: chosenAction.ID(), + Arguments: params.String(), + }, + }) + // result of calling the function messages = append(messages, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleTool, Content: result, - Name: msg.ToolCalls[0].Function.Name, - ToolCallID: msg.ToolCalls[0].ID, + Name: chosenAction.ID(), + ToolCallID: chosenAction.ID(), }) - resp, err = a.client.CreateChatCompletion(ctx, + resp, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ Model: a.options.LLMAPI.Model, Messages: messages, @@ -248,7 +319,7 @@ func (a *Agent) consumeJob(job *Job) { } // display OpenAI's response to the original question utilizing our function - msg = resp.Choices[0].Message + msg := resp.Choices[0].Message fmt.Printf("OpenAI answered the original request with: %v\n", msg.Content) diff --git a/agent/options.go b/agent/options.go index 851537c..84699e1 100644 --- a/agent/options.go +++ b/agent/options.go @@ -17,7 +17,7 @@ type options struct { character Character randomIdentityGuidance string randomIdentity bool - actions []Action + actions Actions context context.Context }