Move action context to the job
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
@@ -250,7 +250,7 @@ func (a *Agent) handlePlanning(ctx context.Context, job *types.Job, chosenAction
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := a.runAction(subTaskAction, actionParams)
|
result, err := a.runAction(ctx, subTaskAction, actionParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return conv, fmt.Errorf("error running action: %w", err)
|
return conv, fmt.Errorf("error running action: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ type Agent struct {
|
|||||||
Character Character
|
Character Character
|
||||||
client *openai.Client
|
client *openai.Client
|
||||||
jobQueue chan *types.Job
|
jobQueue chan *types.Job
|
||||||
actionContext *types.ActionContext
|
|
||||||
context *types.ActionContext
|
context *types.ActionContext
|
||||||
|
|
||||||
currentReasoning string
|
currentReasoning string
|
||||||
@@ -136,26 +135,11 @@ func (a *Agent) AddSubscriber(f func(openai.ChatCompletionMessage)) {
|
|||||||
a.newMessagesSubscribers = append(a.newMessagesSubscribers, f)
|
a.newMessagesSubscribers = append(a.newMessagesSubscribers, f)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StopAction stops the current action
|
|
||||||
// if any. Can be called before adding a new job.
|
|
||||||
func (a *Agent) StopAction() {
|
|
||||||
a.Lock()
|
|
||||||
defer a.Unlock()
|
|
||||||
if a.actionContext != nil {
|
|
||||||
xlog.Debug("Stopping current action", "agent", a.Character.Name)
|
|
||||||
a.actionContext.Cancel()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Agent) Context() context.Context {
|
func (a *Agent) Context() context.Context {
|
||||||
return a.context.Context
|
return a.context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) ActionContext() context.Context {
|
// Ask is a blocking call that returns the response as soon as it's ready.
|
||||||
return a.actionContext.Context
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ask is a pre-emptive, blocking call that returns the response as soon as it's ready.
|
|
||||||
// It discards any other computation.
|
// It discards any other computation.
|
||||||
func (a *Agent) Ask(opts ...types.JobOption) *types.JobResult {
|
func (a *Agent) Ask(opts ...types.JobOption) *types.JobResult {
|
||||||
xlog.Debug("Agent Ask()", "agent", a.Character.Name, "model", a.options.LLMAPI.Model)
|
xlog.Debug("Agent Ask()", "agent", a.Character.Name, "model", a.options.LLMAPI.Model)
|
||||||
@@ -163,18 +147,34 @@ func (a *Agent) Ask(opts ...types.JobOption) *types.JobResult {
|
|||||||
xlog.Debug("Agent has finished being asked", "agent", a.Character.Name)
|
xlog.Debug("Agent has finished being asked", "agent", a.Character.Name)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
//a.StopAction()
|
return a.Execute(types.NewJob(
|
||||||
j := types.NewJob(
|
|
||||||
append(
|
append(
|
||||||
opts,
|
opts,
|
||||||
types.WithReasoningCallback(a.options.reasoningCallback),
|
types.WithReasoningCallback(a.options.reasoningCallback),
|
||||||
types.WithResultCallback(a.options.resultCallback),
|
types.WithResultCallback(a.options.resultCallback),
|
||||||
)...,
|
)...,
|
||||||
)
|
))
|
||||||
a.jobQueue <- j
|
}
|
||||||
|
|
||||||
|
// Ask is a pre-emptive, blocking call that returns the response as soon as it's ready.
|
||||||
|
// It discards any other computation.
|
||||||
|
func (a *Agent) Execute(j *types.Job) *types.JobResult {
|
||||||
|
xlog.Debug("Agent Execute()", "agent", a.Character.Name, "model", a.options.LLMAPI.Model)
|
||||||
|
defer func() {
|
||||||
|
xlog.Debug("Agent has finished", "agent", a.Character.Name)
|
||||||
|
}()
|
||||||
|
|
||||||
|
a.Enqueue(j)
|
||||||
return j.Result.WaitResult()
|
return j.Result.WaitResult()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Agent) Enqueue(j *types.Job) {
|
||||||
|
j.ReasoningCallback = a.options.reasoningCallback
|
||||||
|
j.ResultCallback = a.options.resultCallback
|
||||||
|
|
||||||
|
a.jobQueue <- j
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Agent) askLLM(ctx context.Context, conversation []openai.ChatCompletionMessage) (openai.ChatCompletionMessage, error) {
|
func (a *Agent) askLLM(ctx context.Context, conversation []openai.ChatCompletionMessage) (openai.ChatCompletionMessage, error) {
|
||||||
resp, err := a.client.CreateChatCompletion(ctx,
|
resp, err := a.client.CreateChatCompletion(ctx,
|
||||||
openai.ChatCompletionRequest{
|
openai.ChatCompletionRequest{
|
||||||
@@ -225,10 +225,10 @@ func (a *Agent) Memory() RAGDB {
|
|||||||
return a.options.ragdb
|
return a.options.ragdb
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Agent) runAction(chosenAction types.Action, params types.ActionParams) (result types.ActionResult, err error) {
|
func (a *Agent) runAction(ctx context.Context, chosenAction types.Action, params types.ActionParams) (result types.ActionResult, err error) {
|
||||||
for _, act := range a.availableActions() {
|
for _, act := range a.availableActions() {
|
||||||
if act.Definition().Name == chosenAction.Definition().Name {
|
if act.Definition().Name == chosenAction.Definition().Name {
|
||||||
res, err := act.Run(a.actionContext, params)
|
res, err := act.Run(ctx, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ActionResult{}, fmt.Errorf("error running action: %w", err)
|
return types.ActionResult{}, fmt.Errorf("error running action: %w", err)
|
||||||
}
|
}
|
||||||
@@ -407,20 +407,9 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
|
|||||||
conv := job.ConversationHistory
|
conv := job.ConversationHistory
|
||||||
|
|
||||||
a.Lock()
|
a.Lock()
|
||||||
// Set the action context
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
a.actionContext = types.NewActionContext(ctx, cancel)
|
|
||||||
a.selfEvaluationInProgress = selfEvaluation
|
a.selfEvaluationInProgress = selfEvaluation
|
||||||
a.Unlock()
|
a.Unlock()
|
||||||
|
defer job.Cancel()
|
||||||
defer func() {
|
|
||||||
a.Lock()
|
|
||||||
if a.actionContext != nil {
|
|
||||||
a.actionContext.Cancel()
|
|
||||||
a.actionContext = nil
|
|
||||||
}
|
|
||||||
a.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
if selfEvaluation {
|
if selfEvaluation {
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -463,7 +452,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
|
|||||||
a.nextAction = nil
|
a.nextAction = nil
|
||||||
} else {
|
} else {
|
||||||
var err error
|
var err error
|
||||||
chosenAction, actionParams, reasoning, err = a.pickAction(ctx, pickTemplate, conv)
|
chosenAction, actionParams, reasoning, err = a.pickAction(job.GetContext(), pickTemplate, conv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
xlog.Error("Error picking action", "error", err)
|
xlog.Error("Error picking action", "error", err)
|
||||||
job.Result.Finish(err)
|
job.Result.Finish(err)
|
||||||
@@ -506,7 +495,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
|
|||||||
"reasoning", reasoning,
|
"reasoning", reasoning,
|
||||||
)
|
)
|
||||||
|
|
||||||
params, err := a.generateParameters(ctx, pickTemplate, chosenAction, conv, reasoning)
|
params, err := a.generateParameters(job.GetContext(), pickTemplate, chosenAction, conv, reasoning)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
job.Result.Finish(fmt.Errorf("error generating action's parameters: %w", err))
|
job.Result.Finish(fmt.Errorf("error generating action's parameters: %w", err))
|
||||||
return
|
return
|
||||||
@@ -529,7 +518,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
conv, err = a.handlePlanning(ctx, job, chosenAction, actionParams, reasoning, pickTemplate, conv)
|
conv, err = a.handlePlanning(job.GetContext(), job, chosenAction, actionParams, reasoning, pickTemplate, conv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
job.Result.Finish(fmt.Errorf("error running action: %w", err))
|
job.Result.Finish(fmt.Errorf("error running action: %w", err))
|
||||||
return
|
return
|
||||||
@@ -584,7 +573,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
|
|||||||
if !chosenAction.Definition().Name.Is(action.ReplyActionName) {
|
if !chosenAction.Definition().Name.Is(action.ReplyActionName) {
|
||||||
|
|
||||||
if !chosenAction.Definition().Name.Is(action.PlanActionName) {
|
if !chosenAction.Definition().Name.Is(action.PlanActionName) {
|
||||||
result, err := a.runAction(chosenAction, actionParams)
|
result, err := a.runAction(job.GetContext(), chosenAction, actionParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//job.Result.Finish(fmt.Errorf("error running action: %w", err))
|
//job.Result.Finish(fmt.Errorf("error running action: %w", err))
|
||||||
//return
|
//return
|
||||||
@@ -613,7 +602,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
|
|||||||
|
|
||||||
// given the result, we can now ask OpenAI to complete the conversation or
|
// given the result, we can now ask OpenAI to complete the conversation or
|
||||||
// to continue using another tool given the result
|
// to continue using another tool given the result
|
||||||
followingAction, followingParams, reasoning, err := a.pickAction(ctx, reEvaluationTemplate, conv)
|
followingAction, followingParams, reasoning, err := a.pickAction(job.GetContext(), reEvaluationTemplate, conv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
job.Result.Conversation = conv
|
job.Result.Conversation = conv
|
||||||
job.Result.Finish(fmt.Errorf("error picking action: %w", err))
|
job.Result.Finish(fmt.Errorf("error picking action: %w", err))
|
||||||
@@ -736,7 +725,7 @@ func (a *Agent) consumeJob(job *types.Job, role string) {
|
|||||||
|
|
||||||
xlog.Info("Reasoning, ask LLM for a reply", "agent", a.Character.Name)
|
xlog.Info("Reasoning, ask LLM for a reply", "agent", a.Character.Name)
|
||||||
xlog.Debug("Conversation", "conversation", fmt.Sprintf("%+v", conv))
|
xlog.Debug("Conversation", "conversation", fmt.Sprintf("%+v", conv))
|
||||||
msg, err := a.askLLM(ctx, conv)
|
msg, err := a.askLLM(job.GetContext(), conv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
job.Result.Conversation = conv
|
job.Result.Conversation = conv
|
||||||
job.Result.Finish(err)
|
job.Result.Finish(err)
|
||||||
@@ -793,7 +782,6 @@ func (a *Agent) periodicallyRun(timer *time.Timer) {
|
|||||||
// Remember always to reset the timer - if we don't the agent will stop..
|
// Remember always to reset the timer - if we don't the agent will stop..
|
||||||
defer timer.Reset(a.options.periodicRuns)
|
defer timer.Reset(a.options.periodicRuns)
|
||||||
|
|
||||||
a.StopAction()
|
|
||||||
xlog.Debug("Agent is running periodically", "agent", a.Character.Name)
|
xlog.Debug("Agent is running periodically", "agent", a.Character.Name)
|
||||||
|
|
||||||
// TODO: Would be nice if we have a special action to
|
// TODO: Would be nice if we have a special action to
|
||||||
@@ -902,6 +890,5 @@ func (a *Agent) loop(timer *time.Timer, job *types.Job) {
|
|||||||
<-timer.C
|
<-timer.C
|
||||||
}
|
}
|
||||||
xlog.Debug("Agent is consuming a job", "agent", a.Character.Name, "job", job)
|
xlog.Debug("Agent is consuming a job", "agent", a.Character.Name, "job", job)
|
||||||
a.StopAction()
|
|
||||||
a.consumeJob(job, UserRole)
|
a.consumeJob(job, UserRole)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log"
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -14,11 +15,14 @@ type Job struct {
|
|||||||
// It can be a question, a command, or a request to do something
|
// It can be a question, a command, or a request to do something
|
||||||
// The agent will try to do it, and return a response
|
// The agent will try to do it, and return a response
|
||||||
Result *JobResult
|
Result *JobResult
|
||||||
reasoningCallback func(ActionCurrentState) bool
|
ReasoningCallback func(ActionCurrentState) bool
|
||||||
resultCallback func(ActionState)
|
ResultCallback func(ActionState)
|
||||||
ConversationHistory []openai.ChatCompletionMessage
|
ConversationHistory []openai.ChatCompletionMessage
|
||||||
UUID string
|
UUID string
|
||||||
Metadata map[string]interface{}
|
Metadata map[string]interface{}
|
||||||
|
|
||||||
|
context context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// JobResult is the result of a job
|
// JobResult is the result of a job
|
||||||
@@ -43,13 +47,13 @@ func WithConversationHistory(history []openai.ChatCompletionMessage) JobOption {
|
|||||||
|
|
||||||
func WithReasoningCallback(f func(ActionCurrentState) bool) JobOption {
|
func WithReasoningCallback(f func(ActionCurrentState) bool) JobOption {
|
||||||
return func(r *Job) {
|
return func(r *Job) {
|
||||||
r.reasoningCallback = f
|
r.ReasoningCallback = f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithResultCallback(f func(ActionState)) JobOption {
|
func WithResultCallback(f func(ActionState)) JobOption {
|
||||||
return func(r *Job) {
|
return func(r *Job) {
|
||||||
r.resultCallback = f
|
r.ResultCallback = f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,17 +72,17 @@ func NewJobResult() *JobResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (j *Job) Callback(stateResult ActionCurrentState) bool {
|
func (j *Job) Callback(stateResult ActionCurrentState) bool {
|
||||||
if j.reasoningCallback == nil {
|
if j.ReasoningCallback == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return j.reasoningCallback(stateResult)
|
return j.ReasoningCallback(stateResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *Job) CallbackWithResult(stateResult ActionState) {
|
func (j *Job) CallbackWithResult(stateResult ActionState) {
|
||||||
if j.resultCallback == nil {
|
if j.ResultCallback == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
j.resultCallback(stateResult)
|
j.ResultCallback(stateResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithTextImage(text, image string) JobOption {
|
func WithTextImage(text, image string) JobOption {
|
||||||
@@ -134,6 +138,16 @@ func NewJob(opts ...JobOption) *Job {
|
|||||||
o(j)
|
o(j)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ctx context.Context
|
||||||
|
if j.context == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
} else {
|
||||||
|
ctx = j.context
|
||||||
|
}
|
||||||
|
|
||||||
|
context, cancel := context.WithCancel(ctx)
|
||||||
|
j.context = context
|
||||||
|
j.cancel = cancel
|
||||||
return j
|
return j
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,3 +156,17 @@ func WithUUID(uuid string) JobOption {
|
|||||||
j.UUID = uuid
|
j.UUID = uuid
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithContext(ctx context.Context) JobOption {
|
||||||
|
return func(j *Job) {
|
||||||
|
j.context = ctx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *Job) Cancel() {
|
||||||
|
j.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *Job) GetContext() context.Context {
|
||||||
|
return j.context
|
||||||
|
}
|
||||||
|
|||||||
@@ -38,9 +38,8 @@ type Slack struct {
|
|||||||
apiClient *slack.Client
|
apiClient *slack.Client
|
||||||
|
|
||||||
// Track active jobs for cancellation
|
// Track active jobs for cancellation
|
||||||
activeJobs map[string]bool // map[channelID]bool to track if a channel has active processing
|
activeJobs map[string][]*types.Job // map[channelID]bool to track if a channel has active processing
|
||||||
activeJobsMutex sync.RWMutex
|
activeJobsMutex sync.RWMutex
|
||||||
agent *agent.Agent // Reference to the agent to call StopAction
|
|
||||||
|
|
||||||
conversationTracker *ConversationTracker[string]
|
conversationTracker *ConversationTracker[string]
|
||||||
}
|
}
|
||||||
@@ -61,7 +60,7 @@ func NewSlack(config map[string]string) *Slack {
|
|||||||
alwaysReply: config["alwaysReply"] == "true",
|
alwaysReply: config["alwaysReply"] == "true",
|
||||||
conversationTracker: NewConversationTracker[string](duration),
|
conversationTracker: NewConversationTracker[string](duration),
|
||||||
placeholders: make(map[string]string),
|
placeholders: make(map[string]string),
|
||||||
activeJobs: make(map[string]bool),
|
activeJobs: make(map[string][]*types.Job),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,15 +115,17 @@ func (t *Slack) AgentReasoningCallback() func(state types.ActionCurrentState) bo
|
|||||||
// cancelActiveJobForChannel cancels any active job for the given channel
|
// cancelActiveJobForChannel cancels any active job for the given channel
|
||||||
func (t *Slack) cancelActiveJobForChannel(channelID string) {
|
func (t *Slack) cancelActiveJobForChannel(channelID string) {
|
||||||
t.activeJobsMutex.RLock()
|
t.activeJobsMutex.RLock()
|
||||||
isActive := t.activeJobs[channelID]
|
ctxs, exists := t.activeJobs[channelID]
|
||||||
t.activeJobsMutex.RUnlock()
|
t.activeJobsMutex.RUnlock()
|
||||||
|
|
||||||
if isActive && t.agent != nil {
|
if exists {
|
||||||
xlog.Info(fmt.Sprintf("Cancelling active job for channel: %s", channelID))
|
xlog.Info(fmt.Sprintf("Cancelling active job for channel: %s", channelID))
|
||||||
t.agent.StopAction()
|
|
||||||
|
|
||||||
// Mark the job as inactive
|
// Mark the job as inactive
|
||||||
t.activeJobsMutex.Lock()
|
t.activeJobsMutex.Lock()
|
||||||
|
for _, c := range ctxs {
|
||||||
|
c.Cancel()
|
||||||
|
}
|
||||||
delete(t.activeJobs, channelID)
|
delete(t.activeJobs, channelID)
|
||||||
t.activeJobsMutex.Unlock()
|
t.activeJobsMutex.Unlock()
|
||||||
}
|
}
|
||||||
@@ -250,17 +251,6 @@ func (t *Slack) handleChannelMessage(
|
|||||||
message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b))
|
message := replaceUserIDsWithNamesInMessage(api, cleanUpUsernameFromMessage(ev.Text, b))
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// Mark this channel as having an active job
|
|
||||||
t.activeJobsMutex.Lock()
|
|
||||||
t.activeJobs[ev.Channel] = true
|
|
||||||
t.activeJobsMutex.Unlock()
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
// Mark job as complete
|
|
||||||
t.activeJobsMutex.Lock()
|
|
||||||
delete(t.activeJobs, ev.Channel)
|
|
||||||
t.activeJobsMutex.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
imageBytes, mimeType := scanImagesInMessages(api, ev)
|
imageBytes, mimeType := scanImagesInMessages(api, ev)
|
||||||
|
|
||||||
@@ -312,6 +302,27 @@ func (t *Slack) handleChannelMessage(
|
|||||||
}
|
}
|
||||||
agentOptions = append(agentOptions, types.WithMetadata(metadata))
|
agentOptions = append(agentOptions, types.WithMetadata(metadata))
|
||||||
|
|
||||||
|
job := types.NewJob(agentOptions...)
|
||||||
|
|
||||||
|
// Mark this channel as having an active job
|
||||||
|
t.activeJobsMutex.Lock()
|
||||||
|
t.activeJobs[ev.Channel] = append(t.activeJobs[ev.Channel], job)
|
||||||
|
t.activeJobsMutex.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
// Mark job as complete
|
||||||
|
t.activeJobsMutex.Lock()
|
||||||
|
job.Cancel()
|
||||||
|
for i, j := range t.activeJobs[ev.Channel] {
|
||||||
|
if j.UUID == job.UUID {
|
||||||
|
t.activeJobs[ev.Channel] = append(t.activeJobs[ev.Channel][:i], t.activeJobs[ev.Channel][i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.activeJobsMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
res := a.Ask(
|
res := a.Ask(
|
||||||
agentOptions...,
|
agentOptions...,
|
||||||
)
|
)
|
||||||
@@ -695,9 +706,6 @@ func (t *Slack) Start(a *agent.Agent) {
|
|||||||
Markdown: true,
|
Markdown: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the agent reference for use in cancellation
|
|
||||||
t.agent = a
|
|
||||||
|
|
||||||
api := slack.New(
|
api := slack.New(
|
||||||
t.botToken,
|
t.botToken,
|
||||||
// slack.OptionDebug(true),
|
// slack.OptionDebug(true),
|
||||||
|
|||||||
Reference in New Issue
Block a user