From 3b1a54083d5f9663482084491afc3a1bbbcea9df Mon Sep 17 00:00:00 2001 From: mudler Date: Sun, 31 Mar 2024 23:06:28 +0200 Subject: [PATCH] wip --- action/definition.go | 25 +++++++++- action/intention.go | 2 +- action/reply.go | 34 ++++++++++++++ agent/actions.go | 108 +++++++++++++++++++++++++++++-------------- agent/agent_test.go | 2 +- agent/jobs.go | 32 +++++++++++-- 6 files changed, 161 insertions(+), 42 deletions(-) create mode 100644 action/reply.go diff --git a/action/definition.go b/action/definition.go index 5ed3aae..c4b1414 100644 --- a/action/definition.go +++ b/action/definition.go @@ -36,18 +36,39 @@ func (ap ActionParams) String() string { return string(b) } +func (ap ActionParams) Unmarshal(v interface{}) error { + b, err := json.Marshal(ap) + if err != nil { + return err + } + if err := json.Unmarshal(b, v); err != nil { + return err + } + return nil +} + //type ActionDefinition openai.FunctionDefinition type ActionDefinition struct { Properties map[string]jsonschema.Definition Required []string - Name string + Name ActionDefinitionName Description string } +type ActionDefinitionName string + +func (a ActionDefinitionName) Is(name string) bool { + return string(a) == name +} + +func (a ActionDefinitionName) String() string { + return string(a) +} + func (a ActionDefinition) ToFunctionDefinition() openai.FunctionDefinition { return openai.FunctionDefinition{ - Name: a.Name, + Name: a.Name.String(), Description: a.Description, Parameters: jsonschema.Definition{ Type: jsonschema.Object, diff --git a/action/intention.go b/action/intention.go index 4d2cce1..6c04b91 100644 --- a/action/intention.go +++ b/action/intention.go @@ -23,7 +23,7 @@ func (a *IntentAction) Definition() ActionDefinition { Properties: map[string]jsonschema.Definition{ "reasoning": { Type: jsonschema.String, - Description: "The city and state, e.g. San Francisco, CA", + Description: "A detailed reasoning on why you want to call this tool.", }, "tool": { Type: jsonschema.String, diff --git a/action/reply.go b/action/reply.go new file mode 100644 index 0000000..be9b8e9 --- /dev/null +++ b/action/reply.go @@ -0,0 +1,34 @@ +package action + +import ( + "github.com/sashabaranov/go-openai/jsonschema" +) + +// ReplyActionName is the name of the reply action +// used by the LLM to reply to the user without +// any additional processing +const ReplyActionName = "reply" + +func NewReply() *ReplyAction { + return &ReplyAction{} +} + +type ReplyAction struct{} + +func (a *ReplyAction) Run(ActionParams) (string, error) { + return "no-op", nil +} + +func (a *ReplyAction) Definition() ActionDefinition { + return ActionDefinition{ + Name: ReplyActionName, + Description: "Use this tool to reply to the user once we have all the informations we need.", + Properties: map[string]jsonschema.Definition{ + "message": { + Type: jsonschema.String, + Description: "The message to reply with", + }, + }, + Required: []string{"message"}, + } +} diff --git a/agent/actions.go b/agent/actions.go index f276e62..84672cf 100644 --- a/agent/actions.go +++ b/agent/actions.go @@ -3,7 +3,6 @@ package agent import ( "bytes" "context" - "encoding/json" "fmt" "html/template" @@ -31,7 +30,21 @@ func (a Actions) ToTools() []openai.Tool { return tools } -func (a *Agent) decision(ctx context.Context, conversation []openai.ChatCompletionMessage, tools []openai.Tool, toolchoice any) (action.ActionParams, error) { +func (a Actions) Find(name string) Action { + for _, action := range a { + if action.Definition().Name.Is(name) { + return action + } + } + return nil +} + +// decision forces the agent to take on of the available actions +func (a *Agent) decision( + ctx context.Context, + conversation []openai.ChatCompletionMessage, + tools []openai.Tool, toolchoice any) (action.ActionParams, error) { + decision := openai.ChatCompletionRequest{ Model: a.options.LLMAPI.Model, Messages: conversation, @@ -47,13 +60,13 @@ func (a *Agent) decision(ctx context.Context, conversation []openai.ChatCompleti msg := resp.Choices[0].Message if len(msg.ToolCalls) != 1 { + fmt.Println(msg) return nil, fmt.Errorf("len(toolcalls): %v", len(msg.ToolCalls)) } params := action.ActionParams{} if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil { fmt.Println("can't read params", err) - return nil, err } @@ -61,33 +74,71 @@ func (a *Agent) decision(ctx context.Context, conversation []openai.ChatCompleti } func (a *Agent) generateParameters(ctx context.Context, action Action, conversation []openai.ChatCompletionMessage) (action.ActionParams, error) { - return a.decision(ctx, conversation, a.options.actions.ToTools(), action.Definition().Name) + return a.decision(ctx, + conversation, + a.options.actions.ToTools(), + action.Definition().Name) } const pickActionTemplate = `You can take any of the following tools: -{{range .Actions}}{{.Name}}: {{.Description}}{{end}} +{{range .Actions -}} +{{.Name}}: {{.Description }} +{{ end }} -or none. Given the text below, decide which action to take and explain the reasoning behind it. For answering without picking a choice, reply with 'none'. +To answer back to the user, use the "answer" tool. +Given the text below, decide which action to take and explain the detailed reasoning behind it. For answering without picking a choice, reply with 'none'. -{{range .Messages}}{{.Content}}{{end}} +{{range .Messages }} +{{if eq .Role "tool"}}Tool result{{else}}{{.Role}}{{ end }}: {{.Content}} +{{if .FunctionCall}} +Tool called with: {{.FunctionCall}} +{{end}} +{{range .ToolCalls}} +{{.Name}}: {{.Arguments}} +{{end}} +{{end}} ` -func (a *Agent) pickAction(ctx context.Context, messages []openai.ChatCompletionMessage) (Action, error) { +const reEvalTemplate = `You can take any of the following tools: + +{{range .Actions}}{{.Name}}: {{.Description}}{{end}} + +To answer back to the user, use the "answer" tool. +For answering without picking a choice, reply with 'none'. + +Given the text below, decide which action to take and explain the reasoning behind it. + +{{range .Messages -}} +{{if eq .Role "tool" }}Tool result{{else}}{{.Role}}: {{ end }}{{.Content }} +{{if .FunctionCall}}Tool called with: {{.FunctionCall}}{{end}} +{{range .ToolCalls}} +{{.Name}}: {{.Arguments}} +{{end}} +{{end}} + +We already have called tools. Evaluate the current situation and decide if we need to execute other tools or answer back with a result.` + +// pickAction picks an action based on the conversation +func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (Action, string, error) { actionChoice := struct { Intent string `json:"tool"` Reasoning string `json:"reasoning"` }{} prompt := bytes.NewBuffer([]byte{}) - tmpl, err := template.New("pickAction").Parse(pickActionTemplate) + + tmpl, err := template.New("pickAction").Parse(templ) if err != nil { - return nil, err + return nil, "", err } - definitions := []action.ActionDefinition{} + + // It can pick the reply action too + definitions := []action.ActionDefinition{action.NewReply().Definition()} for _, m := range a.options.actions { definitions = append(definitions, m.Definition()) } + err = tmpl.Execute(prompt, struct { Actions []action.ActionDefinition Messages []openai.ChatCompletionMessage @@ -96,14 +147,14 @@ func (a *Agent) pickAction(ctx context.Context, messages []openai.ChatCompletion Messages: messages, }) if err != nil { - return nil, err + return nil, "", err } - fmt.Println(prompt.String()) + fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===") actionsID := []string{} for _, m := range a.options.actions { - actionsID = append(actionsID, m.Definition().Name) + actionsID = append(actionsID, m.Definition().Name.String()) } intentionsTools := action.NewIntention(actionsID...) @@ -120,37 +171,26 @@ func (a *Agent) pickAction(ctx context.Context, messages []openai.ChatCompletion intentionsTools.Definition().Name) if err != nil { fmt.Println("failed decision", err) - return nil, err + return nil, "", err } - dat, err := json.Marshal(params) + err = params.Unmarshal(&actionChoice) if err != nil { - return nil, err - } - - err = json.Unmarshal(dat, &actionChoice) - if err != nil { - return nil, err + return nil, "", err } fmt.Printf("Action choice: %v\n", actionChoice) + if actionChoice.Intent == "" || actionChoice.Intent == "none" { - return nil, fmt.Errorf("no intent detected") + return nil, "", fmt.Errorf("no intent detected") } // Find the action - var action Action - for _, a := range a.options.actions { - if a.Definition().Name == actionChoice.Intent { - action = a - break - } - } - - if action == nil { + chosenAction := append(a.options.actions, action.NewReply()).Find(actionChoice.Intent) + if chosenAction == nil { fmt.Println("No action found for intent: ", actionChoice.Intent) - return nil, fmt.Errorf("No action found for intent:" + actionChoice.Intent) + return nil, "", fmt.Errorf("No action found for intent:" + actionChoice.Intent) } - return action, nil + return chosenAction, actionChoice.Reasoning, nil } diff --git a/agent/agent_test.go b/agent/agent_test.go index d088f57..9c54fcf 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -10,7 +10,7 @@ import ( "github.com/sashabaranov/go-openai/jsonschema" ) -const testActionResult = "It's going to be windy" +const testActionResult = "In Boston it's 30C today, it's sunny, and humidity is at 98%" var _ Action = &TestAction{} diff --git a/agent/jobs.go b/agent/jobs.go index 8cc83f3..fc26b59 100644 --- a/agent/jobs.go +++ b/agent/jobs.go @@ -96,11 +96,20 @@ func (a *Agent) consumeJob(job *Job) { }) } - chosenAction, err := a.pickAction(ctx, messages) + // choose an action first + chosenAction, reasoning, err := a.pickAction(ctx, pickActionTemplate, messages) if err != nil { fmt.Printf("error picking action: %v\n", err) return } + + if chosenAction.Definition().Name.Is(action.ReplyActionName) { + fmt.Println("No action to do, just reply") + job.Result.SetResult(reasoning) + job.Result.Finish() + return + } + params, err := a.generateParameters(ctx, chosenAction, messages) if err != nil { fmt.Printf("error generating parameters: %v\n", err) @@ -124,7 +133,7 @@ func (a *Agent) consumeJob(job *Job) { messages = append(messages, openai.ChatCompletionMessage{ Role: "assistant", FunctionCall: &openai.FunctionCall{ - Name: chosenAction.Definition().Name, + Name: chosenAction.Definition().Name.String(), Arguments: params.String(), }, }) @@ -133,10 +142,25 @@ func (a *Agent) consumeJob(job *Job) { messages = append(messages, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleTool, Content: result, - Name: chosenAction.Definition().Name, - ToolCallID: chosenAction.Definition().Name, + Name: chosenAction.Definition().Name.String(), + ToolCallID: chosenAction.Definition().Name.String(), }) + // given the result, we can now ask OpenAI to complete the conversation or + // to continue using another tool given the result + followingAction, reasoning, err := a.pickAction(ctx, reEvalTemplate, messages) + if err != nil { + fmt.Printf("error picking action: %v\n", err) + return + } + if !chosenAction.Definition().Name.Is(action.ReplyActionName) { + // We need to do another action (?) + // The agent decided to do another action + fmt.Println("Another action to do: ", followingAction.Definition().Name) + fmt.Println("Reasoning: ", reasoning) + return + } + resp, err := a.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ Model: a.options.LLMAPI.Model,