fix: mixed fixups and enhancements (#107)

* chore(Makefile): build react dist if missing

Signed-off-by: mudler <mudler@localai.io>

* fix(planning): don't loose results

Signed-off-by: mudler <mudler@localai.io>

* fix(slack): track user messages when writing on channel

Signed-off-by: mudler <mudler@localai.io>

---------

Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-03-26 17:05:59 +01:00
committed by GitHub
parent 3e36b09376
commit 5cd0eaae3f
6 changed files with 93 additions and 20 deletions

View File

@@ -1,5 +1,6 @@
GOCMD?=go GOCMD?=go
IMAGE_NAME?=webui IMAGE_NAME?=webui
ROOT_DIR:=$(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
prepare-tests: prepare-tests:
docker compose up -d docker compose up -d
@@ -13,12 +14,15 @@ tests: prepare-tests
run-nokb: run-nokb:
$(MAKE) run KBDISABLEINDEX=true $(MAKE) run KBDISABLEINDEX=true
webui/react-ui/dist:
docker run --entrypoint /bin/bash -v $(ROOT_DIR):/app oven/bun:1 -c "cd /app/webui/react-ui && bun install && bun run build"
.PHONY: build .PHONY: build
build: build: webui/react-ui/dist
$(GOCMD) build -o localagent ./ $(GOCMD) build -o localagent ./
.PHONY: run .PHONY: run
run: run: webui/react-ui/dist
$(GOCMD) run ./ $(GOCMD) run ./
build-image: build-image:

View File

@@ -174,17 +174,17 @@ func (a *Agent) generateParameters(ctx context.Context, pickTemplate string, act
) )
} }
func (a *Agent) handlePlanning(ctx context.Context, job *types.Job, chosenAction types.Action, actionParams types.ActionParams, reasoning string, pickTemplate string, conv Messages) error { func (a *Agent) handlePlanning(ctx context.Context, job *types.Job, chosenAction types.Action, actionParams types.ActionParams, reasoning string, pickTemplate string, conv Messages) (Messages, error) {
// Planning: run all the actions in sequence // Planning: run all the actions in sequence
if !chosenAction.Definition().Name.Is(action.PlanActionName) { if !chosenAction.Definition().Name.Is(action.PlanActionName) {
xlog.Debug("no plan action") xlog.Debug("no plan action")
return nil return conv, nil
} }
xlog.Debug("[planning]...") xlog.Debug("[planning]...")
planResult := action.PlanResult{} planResult := action.PlanResult{}
if err := actionParams.Unmarshal(&planResult); err != nil { if err := actionParams.Unmarshal(&planResult); err != nil {
return fmt.Errorf("error unmarshalling plan result: %w", err) return conv, fmt.Errorf("error unmarshalling plan result: %w", err)
} }
stateResult := types.ActionState{ stateResult := types.ActionState{
@@ -207,7 +207,7 @@ func (a *Agent) handlePlanning(ctx context.Context, job *types.Job, chosenAction
} }
if len(planResult.Subtasks) == 0 { if len(planResult.Subtasks) == 0 {
return fmt.Errorf("no subtasks") return conv, fmt.Errorf("no subtasks")
} }
// Execute all subtasks in sequence // Execute all subtasks in sequence
@@ -223,7 +223,7 @@ func (a *Agent) handlePlanning(ctx context.Context, job *types.Job, chosenAction
params, err := a.generateParameters(ctx, pickTemplate, subTaskAction, conv, subTaskReasoning) params, err := a.generateParameters(ctx, pickTemplate, subTaskAction, conv, subTaskReasoning)
if err != nil { if err != nil {
return fmt.Errorf("error generating action's parameters: %w", err) return conv, fmt.Errorf("error generating action's parameters: %w", err)
} }
actionParams = params.actionParams actionParams = params.actionParams
@@ -252,7 +252,7 @@ func (a *Agent) handlePlanning(ctx context.Context, job *types.Job, chosenAction
result, err := a.runAction(subTaskAction, actionParams) result, err := a.runAction(subTaskAction, actionParams)
if err != nil { if err != nil {
return fmt.Errorf("error running action: %w", err) return conv, fmt.Errorf("error running action: %w", err)
} }
stateResult := types.ActionState{ stateResult := types.ActionState{
@@ -270,7 +270,7 @@ func (a *Agent) handlePlanning(ctx context.Context, job *types.Job, chosenAction
conv = a.addFunctionResultToConversation(subTaskAction, actionParams, result, conv) conv = a.addFunctionResultToConversation(subTaskAction, actionParams, result, conv)
} }
return nil return conv, nil
} }
func (a *Agent) availableActions() types.Actions { func (a *Agent) availableActions() types.Actions {

View File

@@ -528,7 +528,9 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
return return
} }
if err := a.handlePlanning(ctx, job, chosenAction, actionParams, reasoning, pickTemplate, conv); err != nil { var err error
conv, err = a.handlePlanning(ctx, job, chosenAction, actionParams, reasoning, pickTemplate, conv)
if err != nil {
job.Result.Finish(fmt.Errorf("error running action: %w", err)) job.Result.Finish(fmt.Errorf("error running action: %w", err))
return return
} }

View File

@@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"os" "os"
"strings" "strings"
@@ -38,6 +37,11 @@ type Slack struct {
placeholderMutex sync.RWMutex placeholderMutex sync.RWMutex
apiClient *slack.Client apiClient *slack.Client
// Track active jobs for cancellation
activeJobs map[string]bool // map[channelID]bool to track if a channel has active processing
activeJobsMutex sync.RWMutex
agent *agent.Agent // Reference to the agent to call StopAction
conversationTracker *ConversationTracker[string] conversationTracker *ConversationTracker[string]
} }
@@ -57,13 +61,20 @@ func NewSlack(config map[string]string) *Slack {
alwaysReply: config["alwaysReply"] == "true", alwaysReply: config["alwaysReply"] == "true",
conversationTracker: NewConversationTracker[string](duration), conversationTracker: NewConversationTracker[string](duration),
placeholders: make(map[string]string), placeholders: make(map[string]string),
activeJobs: make(map[string]bool),
} }
} }
func (t *Slack) AgentResultCallback() func(state types.ActionState) { func (t *Slack) AgentResultCallback() func(state types.ActionState) {
return func(state types.ActionState) { return func(state types.ActionState) {
// The final result callback is intentionally empty as we're handling // Mark the job as completed when we get the final result
// the final update in the handleMention function directly if state.ActionCurrentState.Job != nil && state.ActionCurrentState.Job.Metadata != nil {
if channel, ok := state.ActionCurrentState.Job.Metadata["channel"].(string); ok && channel != "" {
t.activeJobsMutex.Lock()
delete(t.activeJobs, channel)
t.activeJobsMutex.Unlock()
}
}
} }
} }
@@ -102,6 +113,23 @@ func (t *Slack) AgentReasoningCallback() func(state types.ActionCurrentState) bo
} }
} }
// cancelActiveJobForChannel cancels any active job for the given channel
func (t *Slack) cancelActiveJobForChannel(channelID string) {
t.activeJobsMutex.RLock()
isActive := t.activeJobs[channelID]
t.activeJobsMutex.RUnlock()
if isActive && t.agent != nil {
xlog.Info(fmt.Sprintf("Cancelling active job for channel: %s", channelID))
t.agent.StopAction()
// Mark the job as inactive
t.activeJobsMutex.Lock()
delete(t.activeJobs, channelID)
t.activeJobsMutex.Unlock()
}
}
func cleanUpUsernameFromMessage(message string, b *slack.AuthTestResponse) string { func cleanUpUsernameFromMessage(message string, b *slack.AuthTestResponse) string {
cleaned := strings.ReplaceAll(message, "<@"+b.UserID+">", "") cleaned := strings.ReplaceAll(message, "<@"+b.UserID+">", "")
cleaned = strings.ReplaceAll(cleaned, "<@"+b.BotID+">", "") cleaned = strings.ReplaceAll(cleaned, "<@"+b.BotID+">", "")
@@ -214,11 +242,26 @@ func (t *Slack) handleChannelMessage(
return return
} }
// Cancel any active job for this channel before starting a new one
t.cancelActiveJobForChannel(ev.Channel)
currentConv := t.conversationTracker.GetConversation(t.channelID) currentConv := t.conversationTracker.GetConversation(t.channelID)
message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b)) message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b))
go func() { go func() {
// Mark this channel as having an active job
t.activeJobsMutex.Lock()
t.activeJobs[ev.Channel] = true
t.activeJobsMutex.Unlock()
defer func() {
// Mark job as complete
t.activeJobsMutex.Lock()
delete(t.activeJobs, ev.Channel)
t.activeJobsMutex.Unlock()
}()
imageBytes, mimeType := scanImagesInMessages(api, ev) imageBytes, mimeType := scanImagesInMessages(api, ev)
agentOptions := []types.JobOption{ agentOptions := []types.JobOption{
@@ -257,8 +300,18 @@ func (t *Slack) handleChannelMessage(
}) })
} }
t.conversationTracker.AddMessage(
t.channelID, currentConv[len(currentConv)-1],
)
agentOptions = append(agentOptions, types.WithConversationHistory(currentConv)) agentOptions = append(agentOptions, types.WithConversationHistory(currentConv))
// Add channel to metadata for tracking
metadata := map[string]interface{}{
"channel": ev.Channel,
}
agentOptions = append(agentOptions, types.WithMetadata(metadata))
res := a.Ask( res := a.Ask(
agentOptions..., agentOptions...,
) )
@@ -292,9 +345,6 @@ func (t *Slack) handleChannelMessage(
// Function to download the image from a URL and encode it to base64 // Function to download the image from a URL and encode it to base64
func encodeImageFromURL(imageBytes bytes.Buffer) (string, error) { func encodeImageFromURL(imageBytes bytes.Buffer) (string, error) {
// WRITE THIS SOMEWHERE
ioutil.WriteFile("image.jpg", imageBytes.Bytes(), 0644)
// Encode the image data to base64 // Encode the image data to base64
base64Image := base64.StdEncoding.EncodeToString(imageBytes.Bytes()) base64Image := base64.StdEncoding.EncodeToString(imageBytes.Bytes())
return base64Image, nil return base64Image, nil
@@ -639,11 +689,15 @@ func (t *Slack) handleMention(
} }
func (t *Slack) Start(a *agent.Agent) { func (t *Slack) Start(a *agent.Agent) {
postMessageParams := slack.PostMessageParameters{ postMessageParams := slack.PostMessageParameters{
LinkNames: 1, LinkNames: 1,
Markdown: true, Markdown: true,
} }
// Store the agent reference for use in cancellation
t.agent = a
api := slack.New( api := slack.New(
t.botToken, t.botToken,
// slack.OptionDebug(true), // slack.OptionDebug(true),

View File

@@ -3,6 +3,7 @@ package connectors
import ( import (
"context" "context"
"errors" "errors"
"net/http"
"os" "os"
"os/signal" "os/signal"
"slices" "slices"
@@ -99,12 +100,24 @@ func (t *Telegram) handleUpdate(ctx context.Context, b *bot.Bot, a *agent.Agent,
// coming from the gen image actions // coming from the gen image actions
if imagesUrls, exists := res.Metadata[actions.MetadataImages]; exists { if imagesUrls, exists := res.Metadata[actions.MetadataImages]; exists {
for _, url := range xstrings.UniqueSlice(imagesUrls.([]string)) { for _, url := range xstrings.UniqueSlice(imagesUrls.([]string)) {
b.SendPhoto(ctx, &bot.SendPhotoParams{ xlog.Debug("Sending photo", "url", url)
resp, err := http.Get(url)
if err != nil {
xlog.Error("Error downloading image", "error", err.Error())
continue
}
defer resp.Body.Close()
_, err = b.SendPhoto(ctx, &bot.SendPhotoParams{
ChatID: update.Message.Chat.ID, ChatID: update.Message.Chat.ID,
Photo: models.InputFileString{ Photo: models.InputFileUpload{
Data: url, Filename: "image.jpg",
Data: resp.Body,
}, },
}) })
if err != nil {
xlog.Error("Error sending photo", "error", err.Error())
}
} }
} }