From 8601956e53ac330fbf01f89879b3bfa355fc7969 Mon Sep 17 00:00:00 2001 From: mudler Date: Sun, 31 Mar 2024 17:08:34 +0200 Subject: [PATCH] simplify declarations --- agent/actions.go | 64 +++++++++++++++++++++++++++++++----------------- 1 file changed, 42 insertions(+), 22 deletions(-) diff --git a/agent/actions.go b/agent/actions.go index 98f9c5d..713f01a 100644 --- a/agent/actions.go +++ b/agent/actions.go @@ -11,6 +11,7 @@ import ( //"github.com/mudler/local-agent-framework/llm" "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" ) type ActionContext struct { @@ -30,20 +31,33 @@ func (ap ActionParams) String() string { return string(b) } -type ActionDefinition openai.FunctionDefinition - -func (a ActionDefinition) FD() openai.FunctionDefinition { - return openai.FunctionDefinition(a) -} +//type ActionDefinition openai.FunctionDefinition // Actions is something the agent can do type Action interface { - ID() string - Description() string Run(ActionParams) (string, error) Definition() ActionDefinition } +type ActionDefinition struct { + Properties map[string]jsonschema.Definition + Required []string + Name string + Description string +} + +func (a ActionDefinition) ToFunctionDefinition() openai.FunctionDefinition { + return openai.FunctionDefinition{ + Name: a.Name, + Description: a.Description, + Parameters: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: a.Properties, + Required: a.Required, + }, + } +} + var ErrContextCanceled = fmt.Errorf("context canceled") func (a *Agent) Stop() { @@ -140,19 +154,19 @@ func (a Actions) ToTools() []openai.Tool { for _, action := range a { tools = append(tools, openai.Tool{ Type: openai.ToolTypeFunction, - Function: action.Definition().FD(), + Function: action.Definition().ToFunctionDefinition(), }) } return tools } func (a *Agent) generateParameters(ctx context.Context, action Action, conversation []openai.ChatCompletionMessage) (ActionParams, error) { - return a.decision(ctx, conversation, a.options.actions.ToTools(), action.ID()) + return a.decision(ctx, conversation, a.options.actions.ToTools(), action.Definition().Name) } const pickActionTemplate = `You can take any of the following tools: -{{range .Actions}}{{.ID}}: {{.Description}}{{end}} +{{range .Actions}}{{.Name}}: {{.Description}}{{end}} or none. Given the text below, decide which action to take and explain the reasoning behind it. For answering without picking a choice, reply with 'none'. @@ -170,12 +184,15 @@ func (a *Agent) pickAction(ctx context.Context, messages []openai.ChatCompletion if err != nil { return nil, err } - + definitions := []ActionDefinition{} + for _, m := range a.options.actions { + definitions = append(definitions, m.Definition()) + } err = tmpl.Execute(prompt, struct { - Actions []Action + Actions []ActionDefinition Messages []openai.ChatCompletionMessage }{ - Actions: a.options.actions, + Actions: definitions, Messages: messages, }) if err != nil { @@ -186,7 +203,7 @@ func (a *Agent) pickAction(ctx context.Context, messages []openai.ChatCompletion actionsID := []string{} for _, m := range a.options.actions { - actionsID = append(actionsID, m.ID()) + actionsID = append(actionsID, m.Definition().Name) } intentionsTools := NewIntention(actionsID...) @@ -197,7 +214,10 @@ func (a *Agent) pickAction(ctx context.Context, messages []openai.ChatCompletion }, } - params, err := a.decision(ctx, conversation, Actions{intentionsTools}.ToTools(), intentionsTools.ID()) + params, err := a.decision(ctx, + conversation, + Actions{intentionsTools}.ToTools(), + intentionsTools.Definition().Name) if err != nil { fmt.Println("failed decision", err) return nil, err @@ -221,7 +241,7 @@ func (a *Agent) pickAction(ctx context.Context, messages []openai.ChatCompletion // Find the action var action Action for _, a := range a.options.actions { - if a.ID() == actionChoice.Intent { + if a.Definition().Name == actionChoice.Intent { action = a break } @@ -277,9 +297,9 @@ func (a *Agent) consumeJob(job *Job) { var result string for _, action := range a.options.actions { - fmt.Println("Checking action: ", action.ID(), chosenAction.ID()) - if action.ID() == chosenAction.ID() { - fmt.Printf("Running action: %v\n", action.ID()) + fmt.Println("Checking action: ", action.Definition().Name, chosenAction.Definition().Name) + if action.Definition().Name == chosenAction.Definition().Name { + fmt.Printf("Running action: %v\n", action.Definition().Name) if result, err = action.Run(params); err != nil { fmt.Printf("error running action: %v\n", err) return @@ -292,7 +312,7 @@ func (a *Agent) consumeJob(job *Job) { messages = append(messages, openai.ChatCompletionMessage{ Role: "assistant", FunctionCall: &openai.FunctionCall{ - Name: chosenAction.ID(), + Name: chosenAction.Definition().Name, Arguments: params.String(), }, }) @@ -301,8 +321,8 @@ func (a *Agent) consumeJob(job *Job) { messages = append(messages, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleTool, Content: result, - Name: chosenAction.ID(), - ToolCallID: chosenAction.ID(), + Name: chosenAction.Definition().Name, + ToolCallID: chosenAction.Definition().Name, }) resp, err := a.client.CreateChatCompletion(ctx,