commit 3b1a54083d (parent aa62d9ef9e)
Author: mudler
Date: 2024-03-31 23:06:28 +02:00

6 changed files with 161 additions and 42 deletions

View File

@@ -36,18 +36,39 @@ func (ap ActionParams) String() string {
 	return string(b)
 }
 
+func (ap ActionParams) Unmarshal(v interface{}) error {
+	b, err := json.Marshal(ap)
+	if err != nil {
+		return err
+	}
+	if err := json.Unmarshal(b, v); err != nil {
+		return err
+	}
+	return nil
+}
+
 //type ActionDefinition openai.FunctionDefinition
 type ActionDefinition struct {
 	Properties  map[string]jsonschema.Definition
 	Required    []string
-	Name        string
+	Name        ActionDefinitionName
 	Description string
 }
 
+type ActionDefinitionName string
+
+func (a ActionDefinitionName) Is(name string) bool {
+	return string(a) == name
+}
+
+func (a ActionDefinitionName) String() string {
+	return string(a)
+}
+
 func (a ActionDefinition) ToFunctionDefinition() openai.FunctionDefinition {
 	return openai.FunctionDefinition{
-		Name:        a.Name,
+		Name:        a.Name.String(),
 		Description: a.Description,
 		Parameters: jsonschema.Definition{
 			Type: jsonschema.Object,
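As a rough usage sketch (not part of the commit), the new Unmarshal helper together with ActionDefinitionName lets callers decode tool-call parameters into a typed struct and compare action names without raw string equality. The searchParams struct and the "search" name below are made up for illustration; the sketch assumes it sits in the same action package as the code above, with "fmt" imported:

// Hypothetical example, not from this commit.
type searchParams struct {
	Query string `json:"query"` // illustrative field
}

func exampleUnmarshal(def ActionDefinition, params ActionParams) error {
	// Name comparisons now go through ActionDefinitionName.Is
	// instead of comparing plain strings.
	if !def.Name.Is("search") { // "search" is a made-up action name
		return fmt.Errorf("unexpected action %q", def.Name.String())
	}

	// Unmarshal round-trips the params through JSON into the typed struct.
	var p searchParams
	if err := params.Unmarshal(&p); err != nil {
		return err
	}

	fmt.Println("query:", p.Query)
	return nil
}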

View File

@@ -23,7 +23,7 @@ func (a *IntentAction) Definition() ActionDefinition {
 		Properties: map[string]jsonschema.Definition{
 			"reasoning": {
 				Type:        jsonschema.String,
-				Description: "The city and state, e.g. San Francisco, CA",
+				Description: "A detailed reasoning on why you want to call this tool.",
 			},
 			"tool": {
 				Type: jsonschema.String,

action/reply.go Normal file
View File

@@ -0,0 +1,34 @@
+package action
+
+import (
+	"github.com/sashabaranov/go-openai/jsonschema"
+)
+
+// ReplyActionName is the name of the reply action
+// used by the LLM to reply to the user without
+// any additional processing
+const ReplyActionName = "reply"
+
+func NewReply() *ReplyAction {
+	return &ReplyAction{}
+}
+
+type ReplyAction struct{}
+
+func (a *ReplyAction) Run(ActionParams) (string, error) {
+	return "no-op", nil
+}
+
+func (a *ReplyAction) Definition() ActionDefinition {
+	return ActionDefinition{
+		Name:        ReplyActionName,
+		Description: "Use this tool to reply to the user once we have all the informations we need.",
+		Properties: map[string]jsonschema.Definition{
+			"message": {
+				Type:        jsonschema.String,
+				Description: "The message to reply with",
+			},
+		},
+		Required: []string{"message"},
+	}
+}
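For illustration only (not in the diff): the reply action plugs into the existing ToFunctionDefinition conversion like any other action, so it can be exposed to the LLM as a callable tool. A minimal sketch, assuming it sits next to the action package code above, with "fmt" and github.com/sashabaranov/go-openai imported as openai:

// Sketch: expose the reply action as an OpenAI function definition.
func exampleReplyFunction() openai.FunctionDefinition {
	def := NewReply().Definition()

	// The typed name can be checked with the new Is helper.
	fmt.Println("is reply:", def.Name.Is(ReplyActionName)) // prints: is reply: true

	// ToFunctionDefinition (see action.go above) maps Name, Description,
	// Properties and Required onto the go-openai schema.
	return def.ToFunctionDefinition()
}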

View File

@@ -3,7 +3,6 @@ package agent
 import (
 	"bytes"
 	"context"
-	"encoding/json"
 	"fmt"
 	"html/template"
@@ -31,7 +30,21 @@ func (a Actions) ToTools() []openai.Tool {
 	return tools
 }
 
-func (a *Agent) decision(ctx context.Context, conversation []openai.ChatCompletionMessage, tools []openai.Tool, toolchoice any) (action.ActionParams, error) {
+func (a Actions) Find(name string) Action {
+	for _, action := range a {
+		if action.Definition().Name.Is(name) {
+			return action
+		}
+	}
+	return nil
+}
+
+// decision forces the agent to take one of the available actions
+func (a *Agent) decision(
+	ctx context.Context,
+	conversation []openai.ChatCompletionMessage,
+	tools []openai.Tool, toolchoice any) (action.ActionParams, error) {
 	decision := openai.ChatCompletionRequest{
 		Model:    a.options.LLMAPI.Model,
 		Messages: conversation,
@@ -47,13 +60,13 @@ func (a *Agent) decision(ctx context.Context, conversation []openai.ChatCompleti
 	msg := resp.Choices[0].Message
 	if len(msg.ToolCalls) != 1 {
+		fmt.Println(msg)
 		return nil, fmt.Errorf("len(toolcalls): %v", len(msg.ToolCalls))
 	}
 
 	params := action.ActionParams{}
 	if err := params.Read(msg.ToolCalls[0].Function.Arguments); err != nil {
 		fmt.Println("can't read params", err)
 		return nil, err
 	}
@@ -61,33 +74,71 @@ func (a *Agent) decision(ctx context.Context, conversation []openai.ChatCompleti
 }
 
 func (a *Agent) generateParameters(ctx context.Context, action Action, conversation []openai.ChatCompletionMessage) (action.ActionParams, error) {
-	return a.decision(ctx, conversation, a.options.actions.ToTools(), action.Definition().Name)
+	return a.decision(ctx,
+		conversation,
+		a.options.actions.ToTools(),
+		action.Definition().Name)
 }
 
 const pickActionTemplate = `You can take any of the following tools:
-{{range .Actions}}{{.Name}}: {{.Description}}{{end}}
-or none. Given the text below, decide which action to take and explain the reasoning behind it. For answering without picking a choice, reply with 'none'. To answer back to the user, use the "answer" tool.
-{{range .Messages}}{{.Content}}{{end}}
+{{range .Actions -}}
+{{.Name}}: {{.Description }}
+{{ end }}
+Given the text below, decide which action to take and explain the detailed reasoning behind it. For answering without picking a choice, reply with 'none'.
+{{range .Messages }}
+{{if eq .Role "tool"}}Tool result{{else}}{{.Role}}{{ end }}: {{.Content}}
+{{if .FunctionCall}}
+Tool called with: {{.FunctionCall}}
+{{end}}
+{{range .ToolCalls}}
+{{.Name}}: {{.Arguments}}
+{{end}}
+{{end}}
 `
 
-func (a *Agent) pickAction(ctx context.Context, messages []openai.ChatCompletionMessage) (Action, error) {
+const reEvalTemplate = `You can take any of the following tools:
+{{range .Actions}}{{.Name}}: {{.Description}}{{end}}
+To answer back to the user, use the "answer" tool.
+For answering without picking a choice, reply with 'none'.
+Given the text below, decide which action to take and explain the reasoning behind it.
+{{range .Messages -}}
+{{if eq .Role "tool" }}Tool result{{else}}{{.Role}}: {{ end }}{{.Content }}
+{{if .FunctionCall}}Tool called with: {{.FunctionCall}}{{end}}
+{{range .ToolCalls}}
+{{.Name}}: {{.Arguments}}
+{{end}}
+{{end}}
+We already have called tools. Evaluate the current situation and decide if we need to execute other tools or answer back with a result.`
+
+// pickAction picks an action based on the conversation
+func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.ChatCompletionMessage) (Action, string, error) {
 	actionChoice := struct {
 		Intent    string `json:"tool"`
 		Reasoning string `json:"reasoning"`
 	}{}
 
 	prompt := bytes.NewBuffer([]byte{})
-	tmpl, err := template.New("pickAction").Parse(pickActionTemplate)
+	tmpl, err := template.New("pickAction").Parse(templ)
 	if err != nil {
-		return nil, err
+		return nil, "", err
 	}
 
-	definitions := []action.ActionDefinition{}
+	// It can pick the reply action too
+	definitions := []action.ActionDefinition{action.NewReply().Definition()}
 	for _, m := range a.options.actions {
 		definitions = append(definitions, m.Definition())
 	}
 
 	err = tmpl.Execute(prompt, struct {
 		Actions  []action.ActionDefinition
 		Messages []openai.ChatCompletionMessage
@@ -96,14 +147,14 @@ func (a *Agent) pickAction(ctx context.Context, messages []openai.ChatCompletion
 		Messages: messages,
 	})
 	if err != nil {
-		return nil, err
+		return nil, "", err
 	}
 
-	fmt.Println(prompt.String())
+	fmt.Println("=== PROMPT START ===", prompt.String(), "=== PROMPT END ===")
 
 	actionsID := []string{}
 	for _, m := range a.options.actions {
-		actionsID = append(actionsID, m.Definition().Name)
+		actionsID = append(actionsID, m.Definition().Name.String())
 	}
 	intentionsTools := action.NewIntention(actionsID...)
@@ -120,37 +171,26 @@ func (a *Agent) pickAction(ctx context.Context, messages []openai.ChatCompletion
 		intentionsTools.Definition().Name)
 	if err != nil {
 		fmt.Println("failed decision", err)
-		return nil, err
+		return nil, "", err
 	}
 
-	dat, err := json.Marshal(params)
-	if err != nil {
-		return nil, err
-	}
-
-	err = json.Unmarshal(dat, &actionChoice)
-	if err != nil {
-		return nil, err
-	}
+	err = params.Unmarshal(&actionChoice)
+	if err != nil {
+		return nil, "", err
+	}
 
 	fmt.Printf("Action choice: %v\n", actionChoice)
 
 	if actionChoice.Intent == "" || actionChoice.Intent == "none" {
-		return nil, fmt.Errorf("no intent detected")
+		return nil, "", fmt.Errorf("no intent detected")
 	}
 
 	// Find the action
-	var action Action
-	for _, a := range a.options.actions {
-		if a.Definition().Name == actionChoice.Intent {
-			action = a
-			break
-		}
-	}
-
-	if action == nil {
+	chosenAction := append(a.options.actions, action.NewReply()).Find(actionChoice.Intent)
+	if chosenAction == nil {
 		fmt.Println("No action found for intent: ", actionChoice.Intent)
-		return nil, fmt.Errorf("No action found for intent:" + actionChoice.Intent)
+		return nil, "", fmt.Errorf("No action found for intent:" + actionChoice.Intent)
 	}
 
-	return action, nil
+	return chosenAction, actionChoice.Reasoning, nil
 }
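The net effect of Actions.Find plus the reply action is that intent resolution becomes a single lookup over the configured actions with the reply action appended, mirroring the chosenAction line above. A minimal standalone sketch of that pattern (not part of the commit), assuming the same agent package scope so Actions, Action, the action import and "fmt" resolve:

// Sketch of the lookup pattern used in pickAction: the reply action is
// appended so the "reply" intent resolves like any other tool.
func resolveIntent(available Actions, intent string) (Action, error) {
	chosen := append(available, action.NewReply()).Find(intent)
	if chosen == nil {
		return nil, fmt.Errorf("no action found for intent %q", intent)
	}
	return chosen, nil
}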

View File

@@ -10,7 +10,7 @@ import (
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
) )
const testActionResult = "It's going to be windy" const testActionResult = "In Boston it's 30C today, it's sunny, and humidity is at 98%"
var _ Action = &TestAction{} var _ Action = &TestAction{}

View File

@@ -96,11 +96,20 @@ func (a *Agent) consumeJob(job *Job) {
 		})
 	}
 
-	chosenAction, err := a.pickAction(ctx, messages)
+	// choose an action first
+	chosenAction, reasoning, err := a.pickAction(ctx, pickActionTemplate, messages)
 	if err != nil {
 		fmt.Printf("error picking action: %v\n", err)
 		return
 	}
 
+	if chosenAction.Definition().Name.Is(action.ReplyActionName) {
+		fmt.Println("No action to do, just reply")
+		job.Result.SetResult(reasoning)
+		job.Result.Finish()
+		return
+	}
+
 	params, err := a.generateParameters(ctx, chosenAction, messages)
 	if err != nil {
 		fmt.Printf("error generating parameters: %v\n", err)
@@ -124,7 +133,7 @@ func (a *Agent) consumeJob(job *Job) {
 	messages = append(messages, openai.ChatCompletionMessage{
 		Role: "assistant",
 		FunctionCall: &openai.FunctionCall{
-			Name:      chosenAction.Definition().Name,
+			Name:      chosenAction.Definition().Name.String(),
 			Arguments: params.String(),
 		},
 	})
@@ -133,10 +142,25 @@ func (a *Agent) consumeJob(job *Job) {
 	messages = append(messages, openai.ChatCompletionMessage{
 		Role:       openai.ChatMessageRoleTool,
 		Content:    result,
-		Name:       chosenAction.Definition().Name,
-		ToolCallID: chosenAction.Definition().Name,
+		Name:       chosenAction.Definition().Name.String(),
+		ToolCallID: chosenAction.Definition().Name.String(),
 	})
 
+	// given the result, we can now ask OpenAI to complete the conversation or
+	// to continue using another tool given the result
+	followingAction, reasoning, err := a.pickAction(ctx, reEvalTemplate, messages)
+	if err != nil {
+		fmt.Printf("error picking action: %v\n", err)
+		return
+	}
+
+	if !chosenAction.Definition().Name.Is(action.ReplyActionName) {
+		// We need to do another action (?)
+		// The agent decided to do another action
+		fmt.Println("Another action to do: ", followingAction.Definition().Name)
+		fmt.Println("Reasoning: ", reasoning)
+		return
+	}
+
 	resp, err := a.client.CreateChatCompletion(ctx,
 		openai.ChatCompletionRequest{
 			Model:    a.options.LLMAPI.Model,