return statesg

This commit is contained in:
mudler
2024-04-01 22:50:11 +02:00
parent 7c679ead94
commit 8e3a1fcbe5
4 changed files with 50 additions and 26 deletions

View File

@@ -11,6 +11,17 @@ import (
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
) )
type ActionState struct {
ActionCurrentState
Result string
}
type ActionCurrentState struct {
Action Action
Params action.ActionParams
Reasoning string
}
// Actions is something the agent can do // Actions is something the agent can do
type Action interface { type Action interface {
Run(action.ActionParams) (string, error) Run(action.ActionParams) (string, error)
@@ -135,7 +146,7 @@ func (a *Agent) pickAction(ctx context.Context, templ string, messages []openai.
return nil, "", err return nil, "", err
} }
err = hudTmpl.Execute(prompt, a.prepareHUD()) err = hudTmpl.Execute(hud, a.prepareHUD())
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -63,7 +63,7 @@ func New(opts ...Option) (*Agent, error) {
// Ask is a pre-emptive, blocking call that returns the response as soon as it's ready. // Ask is a pre-emptive, blocking call that returns the response as soon as it's ready.
// It discards any other computation. // It discards any other computation.
func (a *Agent) Ask(opts ...JobOption) []string { func (a *Agent) Ask(opts ...JobOption) []ActionState {
//a.StopAction() //a.StopAction()
j := NewJob(opts...) j := NewJob(opts...)
// fmt.Println("Job created", text) // fmt.Println("Job created", text)

View File

@@ -62,13 +62,18 @@ var _ = Describe("Agent test", func() {
go agent.Run() go agent.Run()
defer agent.Stop() defer agent.Stop()
res := agent.Ask( res := agent.Ask(
WithReasoningCallback(func(a Action, ap action.ActionParams, s string) { WithReasoningCallback(func(state ActionCurrentState) bool {
fmt.Println("Reasoning", s) fmt.Println("Reasoning", state)
return true
}), }),
WithText("can you get the weather in boston, and afterward of Milano, Italy?"), WithText("can you get the weather in boston, and afterward of Milano, Italy?"),
) )
Expect(res).To(ContainElement(testActionResult), fmt.Sprint(res)) reasons := []string{}
Expect(res).To(ContainElement(testActionResult2), fmt.Sprint(res)) for _, r := range res {
reasons = append(reasons, r.Result)
}
Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
Expect(reasons).To(ContainElement(testActionResult2), fmt.Sprint(res))
}) })
It("pick the correct action", func() { It("pick the correct action", func() {
agent, err := New( agent, err := New(
@@ -83,7 +88,11 @@ var _ = Describe("Agent test", func() {
res := agent.Ask( res := agent.Ask(
WithText("can you get the weather in boston?"), WithText("can you get the weather in boston?"),
) )
Expect(res).To(ContainElement(testActionResult), fmt.Sprint(res)) reasons := []string{}
for _, r := range res {
reasons = append(reasons, r.Result)
}
Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res))
}) })
}) })
}) })

View File

@@ -17,28 +17,28 @@ type Job struct {
Text string Text string
Image string // base64 encoded image Image string // base64 encoded image
Result *JobResult Result *JobResult
reasoningCallback func(Action, action.ActionParams, string) reasoningCallback func(ActionCurrentState) bool
resultCallback func(Action, action.ActionParams, string, string) resultCallback func(ActionState)
} }
// JobResult is the result of a job // JobResult is the result of a job
type JobResult struct { type JobResult struct {
sync.Mutex sync.Mutex
// The result of a job // The result of a job
Data []string State []ActionState
Error error Error error
ready chan bool ready chan bool
} }
type JobOption func(*Job) type JobOption func(*Job)
func WithReasoningCallback(f func(Action, action.ActionParams, string)) JobOption { func WithReasoningCallback(f func(ActionCurrentState) bool) JobOption {
return func(r *Job) { return func(r *Job) {
r.reasoningCallback = f r.reasoningCallback = f
} }
} }
func WithResultCallback(f func(Action, action.ActionParams, string, string)) JobOption { func WithResultCallback(f func(ActionState)) JobOption {
return func(r *Job) { return func(r *Job) {
r.resultCallback = f r.resultCallback = f
} }
@@ -52,18 +52,18 @@ func NewJobResult() *JobResult {
return r return r
} }
func (j *Job) Callback(a Action, p action.ActionParams, s string) { func (j *Job) Callback(stateResult ActionCurrentState) bool {
if j.reasoningCallback == nil { if j.reasoningCallback == nil {
return return true
} }
j.reasoningCallback(a, p, s) return j.reasoningCallback(stateResult)
} }
func (j *Job) CallbackWithResult(a Action, p action.ActionParams, s, r string) { func (j *Job) CallbackWithResult(stateResult ActionState) {
if j.resultCallback == nil { if j.resultCallback == nil {
return return
} }
j.resultCallback(a, p, s, r) j.resultCallback(stateResult)
} }
func WithImage(image string) JobOption { func WithImage(image string) JobOption {
@@ -93,11 +93,11 @@ func NewJob(opts ...JobOption) *Job {
} }
// SetResult sets the result of a job // SetResult sets the result of a job
func (j *JobResult) SetResult(text string) { func (j *JobResult) SetResult(text ActionState) {
j.Lock() j.Lock()
defer j.Unlock() defer j.Unlock()
j.Data = append(j.Data, text) j.State = append(j.State, text)
} }
// SetResult sets the result of a job // SetResult sets the result of a job
@@ -110,11 +110,11 @@ func (j *JobResult) Finish(e error) {
} }
// WaitResult waits for the result of a job // WaitResult waits for the result of a job
func (j *JobResult) WaitResult() []string { func (j *JobResult) WaitResult() []ActionState {
<-j.ready <-j.ready
j.Lock() j.Lock()
defer j.Unlock() defer j.Unlock()
return j.Data return j.State
} }
const pickActionTemplate = `You can take any of the following tools: const pickActionTemplate = `You can take any of the following tools:
@@ -185,8 +185,7 @@ func (a *Agent) consumeJob(job *Job) {
} }
if chosenAction == nil || chosenAction.Definition().Name.Is(action.ReplyActionName) { if chosenAction == nil || chosenAction.Definition().Name.Is(action.ReplyActionName) {
fmt.Println("No action to do, just reply") job.Result.SetResult(ActionState{ActionCurrentState{nil, nil, "No action to do, just reply"}, ""})
job.Result.SetResult(reasoning)
return return
} }
@@ -196,7 +195,11 @@ func (a *Agent) consumeJob(job *Job) {
return return
} }
job.Callback(chosenAction, params.actionParams, reasoning) if !job.Callback(ActionCurrentState{chosenAction, params.actionParams, reasoning}) {
fmt.Println("Stop from callback")
job.Result.SetResult(ActionState{ActionCurrentState{chosenAction, params.actionParams, reasoning}, "stopped by callback"})
return
}
if params.actionParams == nil { if params.actionParams == nil {
fmt.Println("no parameters") fmt.Println("no parameters")
@@ -216,8 +219,9 @@ func (a *Agent) consumeJob(job *Job) {
} }
} }
fmt.Printf("Action run result: %v\n", result) fmt.Printf("Action run result: %v\n", result)
job.Result.SetResult(result) stateResult := ActionState{ActionCurrentState{chosenAction, params.actionParams, reasoning}, result}
job.CallbackWithResult(chosenAction, params.actionParams, reasoning, result) job.Result.SetResult(stateResult)
job.CallbackWithResult(stateResult)
// calling the function // calling the function
messages = append(messages, openai.ChatCompletionMessage{ messages = append(messages, openai.ChatCompletionMessage{