From 3827ebebdfb3fd5a08918df8c133ebd95d64834a Mon Sep 17 00:00:00 2001 From: mudler Date: Sat, 8 Mar 2025 17:52:19 +0100 Subject: [PATCH] feat: add capability to understand images Signed-off-by: mudler --- README.md | 1 + core/agent/actions.go | 11 +++ core/agent/agent.go | 207 ++++++++++++++++++++++++++++++++++-------- core/agent/options.go | 19 +++- core/state/config.go | 1 + core/state/pool.go | 33 ++++--- main.go | 2 + 7 files changed, 218 insertions(+), 56 deletions(-) diff --git a/README.md b/README.md index 0fdf084..26a383e 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,7 @@ LocalAgent can be configured using the following environment variables: | Variable | Description | |-------------------------------|--------------------------------------------------| | `LOCALAGENT_MODEL` | Specifies the test model to use | +| `LOCALAGENT_MULTIMODAL_MODEL` | Specifies a separate model to use with multimodal capabilities (optional, if LOCALAGENT_MODEL does not support multimodality) | | `LOCALAGENT_LLM_API_URL` | URL of the API server | | `LOCALAGENT_API_KEY` | API key for authentication | | `LOCALAGENT_TIMEOUT` | Timeout duration for requests | diff --git a/core/agent/actions.go b/core/agent/actions.go index a004586..2991ec1 100644 --- a/core/agent/actions.go +++ b/core/agent/actions.go @@ -139,6 +139,17 @@ func (m Messages) Save(path string) error { return nil } +func (m Messages) GetLatestUserMessage() *openai.ChatCompletionMessage { + for i := len(m) - 1; i >= 0; i-- { + msg := m[i] + if msg.Role == UserRole { + return &msg + } + } + + return nil +} + func (a *Agent) generateParameters(ctx context.Context, pickTemplate string, act Action, c []openai.ChatCompletionMessage, reasoning string) (*decisionResult, error) { stateHUD, err := renderTemplate(pickTemplate, a.prepareHUD(), a.systemInternalActions(), reasoning) diff --git a/core/agent/agent.go b/core/agent/agent.go index fef642c..9256119 100644 --- a/core/agent/agent.go +++ b/core/agent/agent.go @@ -249,6 +249,171 @@ func (a *Agent) runAction(chosenAction Action, params action.ActionParams) (resu return result, nil } +func (a *Agent) processPrompts() { + //if job.Image != "" { + // TODO: Use llava to explain the image content + //} + // Add custom prompts + for _, prompt := range a.options.prompts { + message, err := prompt.Render(a) + if err != nil { + xlog.Error("Error rendering prompt", "error", err) + continue + } + if message == "" { + xlog.Debug("Prompt is empty, skipping", "agent", a.Character.Name) + continue + } + if !Messages(a.currentConversation).Exist(a.options.systemPrompt) { + a.currentConversation = append([]openai.ChatCompletionMessage{ + { + Role: prompt.Role(), + Content: message, + }}, a.currentConversation...) + } + } + + // TODO: move to a Promptblock? + if a.options.systemPrompt != "" { + if !Messages(a.currentConversation).Exist(a.options.systemPrompt) { + a.currentConversation = append([]openai.ChatCompletionMessage{ + { + Role: "system", + Content: a.options.systemPrompt, + }}, a.currentConversation...) + } + } +} + +func (a *Agent) describeImage(ctx context.Context, model, imageURL string) (string, error) { + resp, err := a.client.CreateChatCompletion(ctx, + openai.ChatCompletionRequest{ + Model: model, Messages: []openai.ChatCompletionMessage{ + { + + Role: "user", + MultiContent: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "What is in the image?", + }, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: imageURL, + }, + }, + }, + }, + }}) + if err != nil { + return "", err + } + if len(resp.Choices) == 0 { + return "", fmt.Errorf("no choices") + } + + return resp.Choices[0].Message.Content, nil +} + +func extractImageContent(message openai.ChatCompletionMessage) (imageURL, text string, e error) { + e = fmt.Errorf("no image found") + if message.MultiContent != nil { + for _, content := range message.MultiContent { + if content.Type == openai.ChatMessagePartTypeImageURL { + imageURL = content.ImageURL.URL + e = nil + } + if content.Type == openai.ChatMessagePartTypeText { + text = content.Text + e = nil + } + } + } + return +} + +func (a *Agent) processUserInputs(job *Job, role string) { + + noNewMessage := job.Text == "" && job.Image == "" + onlyText := job.Text != "" && job.Image == "" + + // walk conversation history, and check if last message from user contains image. + // If it does, we need to describe the image first with a model that supports image understanding (if the current model doesn't support it) + // and add it to the conversation context + if a.options.SeparatedMultimodalModel() && noNewMessage { + lastUserMessage := a.currentConversation.GetLatestUserMessage() + if lastUserMessage != nil { + imageURL, text, err := extractImageContent(*lastUserMessage) + if err == nil { + // We have an image, we need to describe it first + // and add it to the conversation context + imageDescription, err := a.describeImage(a.context.Context, a.options.LLMAPI.MultimodalModel, imageURL) + if err != nil { + xlog.Error("Error describing image", "error", err) + } else { + // We replace the user message with the image description + // and add the user text to the conversation + lastUserMessage.Content = fmt.Sprintf("The user shared an image which can be described as: %s", imageDescription) + lastUserMessage.MultiContent = nil + lastUserMessage.Role = "system" + a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{ + Role: role, + Content: text, + }) + } + } + } + + } + + if onlyText { + a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{ + Role: role, + Content: job.Text, + }) + } + + if job.Image != "" { + // If an image is present with the text + // we have two cases: if the model supports both images and text, we can send both + // if the model supports only text, we can send the text only and we need to describe the image first with a model that support image understanding and add it to the conversation context + if a.options.SeparatedMultimodalModel() { + // We need to describe the image first + imageDescription, err := a.describeImage(a.context.Context, a.options.LLMAPI.Model, job.Image) + if err != nil { + xlog.Error("Error describing image", "error", err) + } else { + a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{ + Role: "system", + Content: fmt.Sprintf("The user shared an image which can be described as: %s", imageDescription), + }) + a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{ + Role: role, + Content: job.Text, + }) + } + } else { + // Just append to the message both the image and the text + a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{ + Role: role, + MultiContent: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: job.Text, + }, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: job.Image, + }, + }, + }, + }) + } + } +} + func (a *Agent) consumeJob(job *Job, role string) { a.Lock() paused := a.pause @@ -290,46 +455,8 @@ func (a *Agent) consumeJob(job *Job, role string) { }() } - //if job.Image != "" { - // TODO: Use llava to explain the image content - //} - // Add custom prompts - for _, prompt := range a.options.prompts { - message, err := prompt.Render(a) - if err != nil { - xlog.Error("Error rendering prompt", "error", err) - continue - } - if message == "" { - xlog.Debug("Prompt is empty, skipping", "agent", a.Character.Name) - continue - } - if !Messages(a.currentConversation).Exist(a.options.systemPrompt) { - a.currentConversation = append([]openai.ChatCompletionMessage{ - { - Role: prompt.Role(), - Content: message, - }}, a.currentConversation...) - } - } - - // TODO: move to a Promptblock? - if a.options.systemPrompt != "" { - if !Messages(a.currentConversation).Exist(a.options.systemPrompt) { - a.currentConversation = append([]openai.ChatCompletionMessage{ - { - Role: "system", - Content: a.options.systemPrompt, - }}, a.currentConversation...) - } - } - - if job.Text != "" { - a.currentConversation = append(a.currentConversation, openai.ChatCompletionMessage{ - Role: role, - Content: job.Text, - }) - } + a.processPrompts() + a.processUserInputs(job, role) // RAG a.knowledgeBaseLookup() diff --git a/core/agent/options.go b/core/agent/options.go index 2b93d9b..60b325a 100644 --- a/core/agent/options.go +++ b/core/agent/options.go @@ -7,10 +7,12 @@ import ( ) type Option func(*options) error + type llmOptions struct { - APIURL string - APIKey string - Model string + APIURL string + APIKey string + Model string + MultimodalModel string } type options struct { @@ -44,6 +46,10 @@ type options struct { conversationsPath string } +func (o *options) SeparatedMultimodalModel() bool { + return o.LLMAPI.MultimodalModel != "" && o.LLMAPI.Model != o.LLMAPI.MultimodalModel +} + func defaultOptions() *options { return &options{ periodicRuns: 15 * time.Minute, @@ -209,6 +215,13 @@ func WithLLMAPIKey(key string) Option { } } +func WithMultimodalModel(model string) Option { + return func(o *options) error { + o.LLMAPI.MultimodalModel = model + return nil + } +} + func WithPermanentGoal(goal string) Option { return func(o *options) error { o.permanentGoal = goal diff --git a/core/state/config.go b/core/state/config.go index 6db904b..732d6ae 100644 --- a/core/state/config.go +++ b/core/state/config.go @@ -34,6 +34,7 @@ type AgentConfig struct { // This is what needs to be part of ActionsConfig Model string `json:"model" form:"model"` + MultimodalModel string `json:"multimodal_model" form:"multimodal_model"` Name string `json:"name" form:"name"` HUD bool `json:"hud" form:"hud"` StandaloneJob bool `json:"standalone_job" form:"standalone_job"` diff --git a/core/state/pool.go b/core/state/pool.go index 1e31bb8..4ea38f6 100644 --- a/core/state/pool.go +++ b/core/state/pool.go @@ -21,18 +21,18 @@ import ( type AgentPool struct { sync.Mutex - file string - pooldir string - pool AgentPoolData - agents map[string]*Agent - managers map[string]sse.Manager - agentStatus map[string]*Status - apiURL, model, localRAGAPI, apiKey string - availableActions func(*AgentConfig) func(ctx context.Context) []Action - connectors func(*AgentConfig) []Connector - promptBlocks func(*AgentConfig) []PromptBlock - timeout string - conversationLogs string + file string + pooldir string + pool AgentPoolData + agents map[string]*Agent + managers map[string]sse.Manager + agentStatus map[string]*Status + apiURL, model, multimodalModel, localRAGAPI, apiKey string + availableActions func(*AgentConfig) func(ctx context.Context) []Action + connectors func(*AgentConfig) []Connector + promptBlocks func(*AgentConfig) []PromptBlock + timeout string + conversationLogs string } type Status struct { @@ -66,7 +66,7 @@ func loadPoolFromFile(path string) (*AgentPoolData, error) { } func NewAgentPool( - model, apiURL, apiKey, directory string, + model, multimodalModel, apiURL, apiKey, directory string, LocalRAGAPI string, availableActions func(*AgentConfig) func(ctx context.Context) []agent.Action, connectors func(*AgentConfig) []Connector, @@ -91,6 +91,7 @@ func NewAgentPool( pooldir: directory, apiURL: apiURL, model: model, + multimodalModel: multimodalModel, localRAGAPI: LocalRAGAPI, apiKey: apiKey, agents: make(map[string]*Agent), @@ -114,6 +115,7 @@ func NewAgentPool( apiURL: apiURL, pooldir: directory, model: model, + multimodalModel: multimodalModel, apiKey: apiKey, agents: make(map[string]*Agent), managers: make(map[string]sse.Manager), @@ -165,6 +167,10 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error manager := sse.NewManager(5) ctx := context.Background() model := a.model + multimodalModel := a.multimodalModel + if config.MultimodalModel != "" { + multimodalModel = config.MultimodalModel + } if config.Model != "" { model = config.Model } @@ -244,6 +250,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error return true }), WithSystemPrompt(config.SystemPrompt), + WithMultimodalModel(multimodalModel), WithAgentResultCallback(func(state ActionState) { a.Lock() if _, ok := a.agentStatus[name]; !ok { diff --git a/main.go b/main.go index f460dc0..c11ca00 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( ) var testModel = os.Getenv("LOCALAGENT_MODEL") +var multimodalModel = os.Getenv("LOCALAGENT_MULTIMODAL_MODEL") var apiURL = os.Getenv("LOCALAGENT_LLM_API_URL") var apiKey = os.Getenv("LOCALAGENT_API_KEY") var timeout = os.Getenv("LOCALAGENT_TIMEOUT") @@ -45,6 +46,7 @@ func main() { // Create the agent pool pool, err := state.NewAgentPool( testModel, + multimodalModel, apiURL, apiKey, stateDir,