Move action context to the job

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-03-26 22:37:25 +01:00
parent 2713349c75
commit 6e888f6008
4 changed files with 100 additions and 77 deletions

View File

@@ -250,7 +250,7 @@ func (a *Agent) handlePlanning(ctx context.Context, job *types.Job, chosenAction
break break
} }
result, err := a.runAction(subTaskAction, actionParams) result, err := a.runAction(ctx, subTaskAction, actionParams)
if err != nil { if err != nil {
return conv, fmt.Errorf("error running action: %w", err) return conv, fmt.Errorf("error running action: %w", err)
} }

View File

@@ -23,12 +23,11 @@ const (
type Agent struct { type Agent struct {
sync.Mutex sync.Mutex
options *options options *options
Character Character Character Character
client *openai.Client client *openai.Client
jobQueue chan *types.Job jobQueue chan *types.Job
actionContext *types.ActionContext context *types.ActionContext
context *types.ActionContext
currentReasoning string currentReasoning string
currentState *action.AgentInternalState currentState *action.AgentInternalState
@@ -136,26 +135,11 @@ func (a *Agent) AddSubscriber(f func(openai.ChatCompletionMessage)) {
a.newMessagesSubscribers = append(a.newMessagesSubscribers, f) a.newMessagesSubscribers = append(a.newMessagesSubscribers, f)
} }
// StopAction stops the current action
// if any. Can be called before adding a new job.
func (a *Agent) StopAction() {
a.Lock()
defer a.Unlock()
if a.actionContext != nil {
xlog.Debug("Stopping current action", "agent", a.Character.Name)
a.actionContext.Cancel()
}
}
func (a *Agent) Context() context.Context { func (a *Agent) Context() context.Context {
return a.context.Context return a.context.Context
} }
func (a *Agent) ActionContext() context.Context { // Ask is a blocking call that returns the response as soon as it's ready.
return a.actionContext.Context
}
// Ask is a pre-emptive, blocking call that returns the response as soon as it's ready.
// It discards any other computation. // It discards any other computation.
func (a *Agent) Ask(opts ...types.JobOption) *types.JobResult { func (a *Agent) Ask(opts ...types.JobOption) *types.JobResult {
xlog.Debug("Agent Ask()", "agent", a.Character.Name, "model", a.options.LLMAPI.Model) xlog.Debug("Agent Ask()", "agent", a.Character.Name, "model", a.options.LLMAPI.Model)
@@ -163,18 +147,34 @@ func (a *Agent) Ask(opts ...types.JobOption) *types.JobResult {
xlog.Debug("Agent has finished being asked", "agent", a.Character.Name) xlog.Debug("Agent has finished being asked", "agent", a.Character.Name)
}() }()
//a.StopAction() return a.Execute(types.NewJob(
j := types.NewJob(
append( append(
opts, opts,
types.WithReasoningCallback(a.options.reasoningCallback), types.WithReasoningCallback(a.options.reasoningCallback),
types.WithResultCallback(a.options.resultCallback), types.WithResultCallback(a.options.resultCallback),
)..., )...,
) ))
a.jobQueue <- j }
// Execute enqueues the given job and blocks until its result is ready,
// then returns that result.
func (a *Agent) Execute(j *types.Job) *types.JobResult {
xlog.Debug("Agent Execute()", "agent", a.Character.Name, "model", a.options.LLMAPI.Model)
defer func() {
xlog.Debug("Agent has finished", "agent", a.Character.Name)
}()
a.Enqueue(j)
return j.Result.WaitResult() return j.Result.WaitResult()
} }
// Enqueue stamps the agent-level reasoning and result callbacks onto the job
// and then sends the job on the agent's job queue.
//
// Any callbacks already present on the job are overwritten by the agent's.
// NOTE(review): whether the channel send blocks depends on how jobQueue is
// created, which is not visible here — confirm against the constructor.
func (a *Agent) Enqueue(j *types.Job) {
	opts := a.options
	j.ReasoningCallback, j.ResultCallback = opts.reasoningCallback, opts.resultCallback

	a.jobQueue <- j
}
func (a *Agent) askLLM(ctx context.Context, conversation []openai.ChatCompletionMessage) (openai.ChatCompletionMessage, error) { func (a *Agent) askLLM(ctx context.Context, conversation []openai.ChatCompletionMessage) (openai.ChatCompletionMessage, error) {
resp, err := a.client.CreateChatCompletion(ctx, resp, err := a.client.CreateChatCompletion(ctx,
openai.ChatCompletionRequest{ openai.ChatCompletionRequest{
@@ -225,10 +225,10 @@ func (a *Agent) Memory() RAGDB {
return a.options.ragdb return a.options.ragdb
} }
func (a *Agent) runAction(chosenAction types.Action, params types.ActionParams) (result types.ActionResult, err error) { func (a *Agent) runAction(ctx context.Context, chosenAction types.Action, params types.ActionParams) (result types.ActionResult, err error) {
for _, act := range a.availableActions() { for _, act := range a.availableActions() {
if act.Definition().Name == chosenAction.Definition().Name { if act.Definition().Name == chosenAction.Definition().Name {
res, err := act.Run(a.actionContext, params) res, err := act.Run(ctx, params)
if err != nil { if err != nil {
return types.ActionResult{}, fmt.Errorf("error running action: %w", err) return types.ActionResult{}, fmt.Errorf("error running action: %w", err)
} }
@@ -407,20 +407,9 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
conv := job.ConversationHistory conv := job.ConversationHistory
a.Lock() a.Lock()
// Set the action context
ctx, cancel := context.WithCancel(context.Background())
a.actionContext = types.NewActionContext(ctx, cancel)
a.selfEvaluationInProgress = selfEvaluation a.selfEvaluationInProgress = selfEvaluation
a.Unlock() a.Unlock()
defer job.Cancel()
defer func() {
a.Lock()
if a.actionContext != nil {
a.actionContext.Cancel()
a.actionContext = nil
}
a.Unlock()
}()
if selfEvaluation { if selfEvaluation {
defer func() { defer func() {
@@ -463,7 +452,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
a.nextAction = nil a.nextAction = nil
} else { } else {
var err error var err error
chosenAction, actionParams, reasoning, err = a.pickAction(ctx, pickTemplate, conv) chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv)
if err != nil { if err != nil {
xlog.Error("Error picking action", "error", err) xlog.Error("Error picking action", "error", err)
job.Result.Finish(err) job.Result.Finish(err)
@@ -506,7 +495,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
"reasoning", reasoning, "reasoning", reasoning,
) )
params, err := a.generateParameters(ctx, pickTemplate, chosenAction, conv, reasoning) params, err := a.generateParameters(job.GetContext(), pickTemplate, chosenAction, conv, reasoning)
if err != nil { if err != nil {
job.Result.Finish(fmt.Errorf("error generating action's parameters: %w", err)) job.Result.Finish(fmt.Errorf("error generating action's parameters: %w", err))
return return
@@ -529,7 +518,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
} }
var err error var err error
conv, err = a.handlePlanning(ctx, job, chosenAction, actionParams, reasoning, pickTemplate, conv) conv, err = a.handlePlanning(job.GetContext(), job, chosenAction, actionParams, reasoning, pickTemplate, conv)
if err != nil { if err != nil {
job.Result.Finish(fmt.Errorf("error running action: %w", err)) job.Result.Finish(fmt.Errorf("error running action: %w", err))
return return
@@ -584,7 +573,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
if !chosenAction.Definition().Name.Is(action.ReplyActionName) { if !chosenAction.Definition().Name.Is(action.ReplyActionName) {
if !chosenAction.Definition().Name.Is(action.PlanActionName) { if !chosenAction.Definition().Name.Is(action.PlanActionName) {
result, err := a.runAction(chosenAction, actionParams) result, err := a.runAction(job.GetContext(), chosenAction, actionParams)
if err != nil { if err != nil {
//job.Result.Finish(fmt.Errorf("error running action: %w", err)) //job.Result.Finish(fmt.Errorf("error running action: %w", err))
//return //return
@@ -613,7 +602,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
// given the result, we can now ask OpenAI to complete the conversation or // given the result, we can now ask OpenAI to complete the conversation or
// to continue using another tool given the result // to continue using another tool given the result
followingAction, followingParams, reasoning, err := a.pickAction(ctx, reEvaluationTemplate, conv) followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv)
if err != nil { if err != nil {
job.Result.Conversation = conv job.Result.Conversation = conv
job.Result.Finish(fmt.Errorf("error picking action: %w", err)) job.Result.Finish(fmt.Errorf("error picking action: %w", err))
@@ -736,7 +725,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
xlog.Info("Reasoning, ask LLM for a reply", "agent", a.Character.Name) xlog.Info("Reasoning, ask LLM for a reply", "agent", a.Character.Name)
xlog.Debug("Conversation", "conversation", fmt.Sprintf("%+v", conv)) xlog.Debug("Conversation", "conversation", fmt.Sprintf("%+v", conv))
msg, err := a.askLLM(ctx, conv) msg, err := a.askLLM(job.GetContext(), conv)
if err != nil { if err != nil {
job.Result.Conversation = conv job.Result.Conversation = conv
job.Result.Finish(err) job.Result.Finish(err)
@@ -793,7 +782,6 @@ func (a *Agent) periodicallyRun(timer *time.Timer) {
// Remember always to reset the timer - if we don't the agent will stop.. // Remember always to reset the timer - if we don't the agent will stop..
defer timer.Reset(a.options.periodicRuns) defer timer.Reset(a.options.periodicRuns)
a.StopAction()
xlog.Debug("Agent is running periodically", "agent", a.Character.Name) xlog.Debug("Agent is running periodically", "agent", a.Character.Name)
// TODO: Would be nice if we have a special action to // TODO: Would be nice if we have a special action to
@@ -902,6 +890,5 @@ func (a *Agent) loop(timer *time.Timer, job *types.Job) {
<-timer.C <-timer.C
} }
xlog.Debug("Agent is consuming a job", "agent", a.Character.Name, "job", job) xlog.Debug("Agent is consuming a job", "agent", a.Character.Name, "job", job)
a.StopAction()
a.consumeJob(job, UserRole) a.consumeJob(job, UserRole)
} }

View File

@@ -1,6 +1,7 @@
package types package types
import ( import (
"context"
"log" "log"
"sync" "sync"
@@ -14,11 +15,14 @@ type Job struct {
// It can be a question, a command, or a request to do something // It can be a question, a command, or a request to do something
// The agent will try to do it, and return a response // The agent will try to do it, and return a response
Result *JobResult Result *JobResult
reasoningCallback func(ActionCurrentState) bool ReasoningCallback func(ActionCurrentState) bool
resultCallback func(ActionState) ResultCallback func(ActionState)
ConversationHistory []openai.ChatCompletionMessage ConversationHistory []openai.ChatCompletionMessage
UUID string UUID string
Metadata map[string]interface{} Metadata map[string]interface{}
context context.Context
cancel context.CancelFunc
} }
// JobResult is the result of a job // JobResult is the result of a job
@@ -43,13 +47,13 @@ func WithConversationHistory(history []openai.ChatCompletionMessage) JobOption {
func WithReasoningCallback(f func(ActionCurrentState) bool) JobOption { func WithReasoningCallback(f func(ActionCurrentState) bool) JobOption {
return func(r *Job) { return func(r *Job) {
r.reasoningCallback = f r.ReasoningCallback = f
} }
} }
func WithResultCallback(f func(ActionState)) JobOption { func WithResultCallback(f func(ActionState)) JobOption {
return func(r *Job) { return func(r *Job) {
r.resultCallback = f r.ResultCallback = f
} }
} }
@@ -68,17 +72,17 @@ func NewJobResult() *JobResult {
} }
func (j *Job) Callback(stateResult ActionCurrentState) bool { func (j *Job) Callback(stateResult ActionCurrentState) bool {
if j.reasoningCallback == nil { if j.ReasoningCallback == nil {
return true return true
} }
return j.reasoningCallback(stateResult) return j.ReasoningCallback(stateResult)
} }
func (j *Job) CallbackWithResult(stateResult ActionState) { func (j *Job) CallbackWithResult(stateResult ActionState) {
if j.resultCallback == nil { if j.ResultCallback == nil {
return return
} }
j.resultCallback(stateResult) j.ResultCallback(stateResult)
} }
func WithTextImage(text, image string) JobOption { func WithTextImage(text, image string) JobOption {
@@ -134,6 +138,16 @@ func NewJob(opts ...JobOption) *Job {
o(j) o(j)
} }
var ctx context.Context
if j.context == nil {
ctx = context.Background()
} else {
ctx = j.context
}
context, cancel := context.WithCancel(ctx)
j.context = context
j.cancel = cancel
return j return j
} }
@@ -142,3 +156,17 @@ func WithUUID(uuid string) JobOption {
j.UUID = uuid j.UUID = uuid
} }
} }
// WithContext returns a JobOption that stores ctx as the job's context.
func WithContext(ctx context.Context) JobOption {
	assign := func(job *Job) {
		job.context = ctx
	}
	return assign
}
// Cancel cancels the job's context by invoking the cancel function stored on
// the job.
//
// It is a no-op when no cancel function has been set (for example on a Job
// that was not built through NewJob), instead of panicking on a nil call.
func (j *Job) Cancel() {
	if j.cancel == nil {
		return
	}
	j.cancel()
}
// GetContext returns the context associated with the job.
//
// When the job has no context set (for example a Job that was not built
// through NewJob), context.Background() is returned so callers never receive
// a nil context.
func (j *Job) GetContext() context.Context {
	if j.context == nil {
		return context.Background()
	}
	return j.context
}

View File

@@ -38,9 +38,8 @@ type Slack struct {
apiClient *slack.Client apiClient *slack.Client
// Track active jobs for cancellation // Track active jobs for cancellation
activeJobs map[string]bool // map[channelID]bool to track if a channel has active processing activeJobs map[string][]*types.Job // map[channelID]bool to track if a channel has active processing
activeJobsMutex sync.RWMutex activeJobsMutex sync.RWMutex
agent *agent.Agent // Reference to the agent to call StopAction
conversationTracker *ConversationTracker[string] conversationTracker *ConversationTracker[string]
} }
@@ -61,7 +60,7 @@ func NewSlack(config map[string]string) *Slack {
alwaysReply: config["alwaysReply"] == "true", alwaysReply: config["alwaysReply"] == "true",
conversationTracker: NewConversationTracker[string](duration), conversationTracker: NewConversationTracker[string](duration),
placeholders: make(map[string]string), placeholders: make(map[string]string),
activeJobs: make(map[string]bool), activeJobs: make(map[string][]*types.Job),
} }
} }
@@ -116,15 +115,17 @@ func (t *Slack) AgentReasoningCallback() func(state types.ActionCurrentState) bo
// cancelActiveJobForChannel cancels any active job for the given channel // cancelActiveJobForChannel cancels any active job for the given channel
func (t *Slack) cancelActiveJobForChannel(channelID string) { func (t *Slack) cancelActiveJobForChannel(channelID string) {
t.activeJobsMutex.RLock() t.activeJobsMutex.RLock()
isActive := t.activeJobs[channelID] ctxs, exists := t.activeJobs[channelID]
t.activeJobsMutex.RUnlock() t.activeJobsMutex.RUnlock()
if isActive && t.agent != nil { if exists {
xlog.Info(fmt.Sprintf("Cancelling active job for channel: %s", channelID)) xlog.Info(fmt.Sprintf("Cancelling active job for channel: %s", channelID))
t.agent.StopAction()
// Mark the job as inactive // Mark the job as inactive
t.activeJobsMutex.Lock() t.activeJobsMutex.Lock()
for _, c := range ctxs {
c.Cancel()
}
delete(t.activeJobs, channelID) delete(t.activeJobs, channelID)
t.activeJobsMutex.Unlock() t.activeJobsMutex.Unlock()
} }
@@ -250,17 +251,6 @@ func (t *Slack) handleChannelMessage(
message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b)) message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b))
go func() { go func() {
// Mark this channel as having an active job
t.activeJobsMutex.Lock()
t.activeJobs[ev.Channel] = true
t.activeJobsMutex.Unlock()
defer func() {
// Mark job as complete
t.activeJobsMutex.Lock()
delete(t.activeJobs, ev.Channel)
t.activeJobsMutex.Unlock()
}()
imageBytes, mimeType := scanImagesInMessages(api, ev) imageBytes, mimeType := scanImagesInMessages(api, ev)
@@ -312,6 +302,27 @@ func (t *Slack) handleChannelMessage(
} }
agentOptions = append(agentOptions, types.WithMetadata(metadata)) agentOptions = append(agentOptions, types.WithMetadata(metadata))
job := types.NewJob(agentOptions...)
// Mark this channel as having an active job
t.activeJobsMutex.Lock()
t.activeJobs[ev.Channel] = append(t.activeJobs[ev.Channel], job)
t.activeJobsMutex.Unlock()
defer func() {
// Mark job as complete
t.activeJobsMutex.Lock()
job.Cancel()
for i, j := range t.activeJobs[ev.Channel] {
if j.UUID == job.UUID {
t.activeJobs[ev.Channel] = append(t.activeJobs[ev.Channel][:i], t.activeJobs[ev.Channel][i+1:]...)
break
}
}
t.activeJobsMutex.Unlock()
}()
res := a.Ask( res := a.Ask(
agentOptions..., agentOptions...,
) )
@@ -695,9 +706,6 @@ func (t *Slack) Start(a *agent.Agent) {
Markdown: true, Markdown: true,
} }
// Store the agent reference for use in cancellation
t.agent = a
api := slack.New( api := slack.New(
t.botToken, t.botToken,
// slack.OptionDebug(true), // slack.OptionDebug(true),