From 61e4be0d0cfa56eccad44e91312bb8c960d61a84 Mon Sep 17 00:00:00 2001 From: mudler Date: Sun, 21 Jan 2024 16:12:34 +0100 Subject: [PATCH] add job/queue logics --- agent/actions.go | 96 ++++++++++++++++++++++++++++ agent/agent.go | 64 +++++++++++++++++++ agent/ask.go | 1 - agent/jobs.go | 57 +++++++++++++++++ agent/{constructor.go => options.go} | 45 ++++--------- agent/state.go | 6 +- llm/json.go | 8 +-- 7 files changed, 236 insertions(+), 41 deletions(-) create mode 100644 agent/agent.go delete mode 100644 agent/ask.go create mode 100644 agent/jobs.go rename agent/{constructor.go => options.go} (72%) diff --git a/agent/actions.go b/agent/actions.go index 4883155..79d94b6 100644 --- a/agent/actions.go +++ b/agent/actions.go @@ -1 +1,97 @@ package agent + +import ( + "context" + "fmt" + + "github.com/mudler/local-agent-framework/llm" +) + +type ActionContext struct { + context.Context + cancelFunc context.CancelFunc +} + +// Actions is something the agent can do +type Action interface { + Description() string + ID() string + Run(map[string]string) error +} + +var ErrContextCanceled = fmt.Errorf("context canceled") + +func (a *Agent) Stop() { + a.context.cancelFunc() +} + +func (a *Agent) Run() error { + // The agent run does two things: + // picks up requests from a queue + // and generates a response/perform actions + + // It is also preemptive. + // That is, it can interrupt the current action + // if another one comes in. + + // If there is no action, periodically evaluate if it has to do something on its own. + + // Expose a REST API to interact with the agent to ask it things + + for { + select { + case job := <-a.jobQueue: + // Consume the job and generate a response + a.consumeJob(job) + case <-a.context.Done(): + // Agent has been canceled, return error + return ErrContextCanceled + } + } +} + +// StopAction stops the current action +// if any. Can be called before adding a new job. +func (a *Agent) StopAction() { + if a.actionContext != nil { + a.actionContext.cancelFunc() + } +} + +func (a *Agent) consumeJob(job *Job) { + // Consume the job and generate a response + // Implement your logic here + + // Set the action context + ctx, cancel := context.WithCancel(context.Background()) + a.actionContext = &ActionContext{ + Context: ctx, + cancelFunc: cancel, + } + + if job.Image != "" { + // TODO: Use llava to explain the image content + } + + if job.Text == "" { + fmt.Println("no text!") + return + } + + decision := struct { + Action string `json:"action"` + }{ + Action: "generate_identity", + } + + llm.GenerateJSON(ctx, a.client, a.options.LLMAPI.Model, + "decide which action to take give the", + &decision) + + // perform the action (if any) + // or reply with a result + + // if there is an action... + job.Result.SetResult("I don't know how to do that yet.") + +} diff --git a/agent/agent.go b/agent/agent.go new file mode 100644 index 0000000..d19b521 --- /dev/null +++ b/agent/agent.go @@ -0,0 +1,64 @@ +package agent + +import ( + "context" + "fmt" + + "github.com/mudler/local-agent-framework/llm" + "github.com/sashabaranov/go-openai" +) + +type Agent struct { + options *options + Character Character + client *openai.Client + jobQueue chan *Job + actionContext *ActionContext + context *ActionContext + availableActions []Action +} + +func New(opts ...Option) (*Agent, error) { + options, err := newOptions(opts...) + if err != nil { + if err != nil { + err = fmt.Errorf("failed to set options: %v", err) + } + return nil, err + } + + client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL) + + c := context.Background() + if options.context != nil { + c = options.context + } + + ctx, cancel := context.WithCancel(c) + a := &Agent{ + options: options, + client: client, + Character: options.character, + context: &ActionContext{ + Context: ctx, + cancelFunc: cancel, + }, + } + + if a.options.randomIdentity { + if err = a.generateIdentity(a.options.randomIdentityGuidance); err != nil { + return a, fmt.Errorf("failed to generate identity: %v", err) + } + } + + return a, nil +} + +// Ask is a pre-emptive, blocking call that returns the response as soon as it's ready. +// It discards any other computation. +func (a *Agent) Ask(text, image string) string { + a.StopAction() + j := NewJob(text, image) + a.jobQueue <- j + return j.Result.WaitResult() +} diff --git a/agent/ask.go b/agent/ask.go deleted file mode 100644 index 4883155..0000000 --- a/agent/ask.go +++ /dev/null @@ -1 +0,0 @@ -package agent diff --git a/agent/jobs.go b/agent/jobs.go new file mode 100644 index 0000000..4cc3b06 --- /dev/null +++ b/agent/jobs.go @@ -0,0 +1,57 @@ +package agent + +import "sync" + +// Job is a request to the agent to do something +type Job struct { + // The job is a request to the agent to do something + // It can be a question, a command, or a request to do something + // The agent will try to do it, and return a response + Text string + Image string // base64 encoded image + Result *JobResult +} + +// JobResult is the result of a job +type JobResult struct { + sync.Mutex + // The result of a job + Text string + ready chan bool +} + +// NewJobResult creates a new job result +func NewJobResult() *JobResult { + return &JobResult{ + ready: make(chan bool), + } +} + +// NewJob creates a new job +// It is a request to the agent to do something +// It has a JobResult to get the result asynchronously +// To wait for a Job result, use JobResult.WaitResult() +func NewJob(text, image string) *Job { + return &Job{ + Text: text, + Image: image, + Result: NewJobResult(), + } +} + +// SetResult sets the result of a job +func (j *JobResult) SetResult(text string) { + j.Lock() + defer j.Unlock() + + j.Text = text + close(j.ready) +} + +// WaitResult waits for the result of a job +func (j *JobResult) WaitResult() string { + <-j.ready + j.Lock() + defer j.Unlock() + return j.Text +} diff --git a/agent/constructor.go b/agent/options.go similarity index 72% rename from agent/constructor.go rename to agent/options.go index b4ce6c1..37bac26 100644 --- a/agent/constructor.go +++ b/agent/options.go @@ -1,12 +1,11 @@ package agent import ( - "fmt" - - "github.com/mudler/local-agent-framework/llm" - "github.com/sashabaranov/go-openai" + "context" + "strings" ) +type Option func(*options) error type llmOptions struct { APIURL string APIKey string @@ -18,16 +17,9 @@ type options struct { character Character randomIdentityGuidance string randomIdentity bool + context context.Context } -type Agent struct { - options *options - Character Character - client *openai.Client -} - -type Option func(*options) error - func defaultOptions() *options { return &options{ LLMAPI: llmOptions{ @@ -58,26 +50,6 @@ func newOptions(opts ...Option) (*options, error) { return options, nil } -func New(opts ...Option) (*Agent, error) { - options, err := newOptions(opts...) - if err != nil { - return nil, err - } - - client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL) - a := &Agent{ - options: options, - client: client, - Character: options.character, - } - - if a.options.randomIdentity { - err = a.generateIdentity(a.options.randomIdentityGuidance) - } - - return a, err -} - func WithLLMAPIURL(url string) Option { return func(o *options) error { o.LLMAPI.APIURL = url @@ -92,6 +64,13 @@ func WithLLMAPIKey(key string) Option { } } +func WithContext(ctx context.Context) Option { + return func(o *options) error { + o.context = ctx + return nil + } +} + func WithModel(model string) Option { return func(o *options) error { o.LLMAPI.Model = model @@ -119,7 +98,7 @@ func FromFile(path string) Option { func WithRandomIdentity(guidance ...string) Option { return func(o *options) error { - o.randomIdentityGuidance = fmt.Sprint(guidance) + o.randomIdentityGuidance = strings.Join(guidance, "") o.randomIdentity = true return nil } diff --git a/agent/state.go b/agent/state.go index ce9e9c4..8f73935 100644 --- a/agent/state.go +++ b/agent/state.go @@ -46,14 +46,14 @@ func (a *Agent) generateIdentity(guidance string) error { if guidance == "" { guidance = "Generate a random character for roleplaying." } - err := llm.GenerateJSONFromStruct(a.client, guidance, a.options.LLMAPI.Model, &a.options.character) + err := llm.GenerateJSONFromStruct(a.context.Context, a.client, guidance, a.options.LLMAPI.Model, &a.options.character) a.Character = a.options.character if err != nil { - return err + return fmt.Errorf("failed to generate JSON from structure: %v", err) } if !a.validCharacter() { - return fmt.Errorf("invalid character") + return fmt.Errorf("generated character is not valid ( guidance: %s ): %v", guidance, a.String()) } return nil } diff --git a/llm/json.go b/llm/json.go index 7d41915..80a87f2 100644 --- a/llm/json.go +++ b/llm/json.go @@ -9,7 +9,7 @@ import ( ) // generateAnswer generates an answer for the given text using the OpenAI API -func GenerateJSON(client *openai.Client, model, text string, i interface{}) error { +func GenerateJSON(ctx context.Context, client *openai.Client, model, text string, i interface{}) error { req := openai.ChatCompletionRequest{ ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject}, Model: model, @@ -22,7 +22,7 @@ func GenerateJSON(client *openai.Client, model, text string, i interface{}) erro }, } - resp, err := client.CreateChatCompletion(context.Background(), req) + resp, err := client.CreateChatCompletion(ctx, req) if err != nil { return fmt.Errorf("failed to generate answer: %v", err) } @@ -37,11 +37,11 @@ func GenerateJSON(client *openai.Client, model, text string, i interface{}) erro return nil } -func GenerateJSONFromStruct(client *openai.Client, guidance, model string, i interface{}) error { +func GenerateJSONFromStruct(ctx context.Context, client *openai.Client, guidance, model string, i interface{}) error { // TODO: use functions? exampleJSON, err := json.Marshal(i) if err != nil { return err } - return GenerateJSON(client, model, "Generate a character as JSON data. "+guidance+". This is the JSON fields that should contain: "+string(exampleJSON), i) + return GenerateJSON(ctx, client, model, "Generate a character as JSON data. "+guidance+". This is the JSON fields that should contain: "+string(exampleJSON), i) }