Compare commits

..

3 Commits

Author SHA1 Message Date
mudler
47492f890f chore(tests): add timeout 2025-04-22 21:46:25 +02:00
mudler
ddac344147 chore: better defaults for parallel jobs
Signed-off-by: mudler <mudler@localai.io>
2025-04-22 21:28:00 +02:00
Ettore Di Giacinto
25bb3fb123 feat: allow the agent to perform things concurrently (#74)
* feat: allow the agent to perform things concurrently

Signed-off-by: mudler <mudler@localai.io>

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* collect errors

Signed-off-by: mudler <mudler@localai.io>

---------

Signed-off-by: mudler <mudler@localai.io>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-04-22 16:49:28 +02:00
9 changed files with 94 additions and 103 deletions

View File

@@ -2,6 +2,7 @@ package agent
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"os" "os"
"sync" "sync"
@@ -989,7 +990,6 @@ func (a *Agent) periodicallyRun(timer *time.Timer) {
} }
func (a *Agent) Run() error { func (a *Agent) Run() error {
a.startNewConversationsConsumer() a.startNewConversationsConsumer()
xlog.Debug("Agent is now running", "agent", a.Character.Name) xlog.Debug("Agent is now running", "agent", a.Character.Name)
// The agent run does two things: // The agent run does two things:
@@ -1004,36 +1004,68 @@ func (a *Agent) Run() error {
// Expose a REST API to interact with the agent to ask it things // Expose a REST API to interact with the agent to ask it things
//todoTimer := time.NewTicker(a.options.periodicRuns)
timer := time.NewTimer(a.options.periodicRuns) timer := time.NewTimer(a.options.periodicRuns)
// we fire the periodicalRunner only once.
go a.periodicalRunRunner(timer)
var errs []error
var muErr sync.Mutex
var wg sync.WaitGroup
parallelJobs := a.options.parallelJobs
if a.options.parallelJobs == 0 {
parallelJobs = 1
}
for i := 0; i < parallelJobs; i++ {
xlog.Debug("Starting agent worker", "worker", i)
wg.Add(1)
go func() {
e := a.run(timer)
muErr.Lock()
errs = append(errs, e)
muErr.Unlock()
wg.Done()
}()
}
wg.Wait()
return errors.Join(errs...)
}
func (a *Agent) run(timer *time.Timer) error {
for { for {
xlog.Debug("Agent is now waiting for a new job", "agent", a.Character.Name) xlog.Debug("Agent is now waiting for a new job", "agent", a.Character.Name)
select { select {
case job := <-a.jobQueue: case job := <-a.jobQueue:
a.loop(timer, job) if !timer.Stop() {
<-timer.C
}
xlog.Debug("Agent is consuming a job", "agent", a.Character.Name, "job", job)
a.consumeJob(job, UserRole)
timer.Reset(a.options.periodicRuns)
case <-a.context.Done(): case <-a.context.Done():
// Agent has been canceled, return error // Agent has been canceled, return error
xlog.Warn("Agent has been canceled", "agent", a.Character.Name) xlog.Warn("Agent has been canceled", "agent", a.Character.Name)
return ErrContextCanceled return ErrContextCanceled
}
}
}
func (a *Agent) periodicalRunRunner(timer *time.Timer) {
for {
select {
case <-a.context.Done():
// Agent has been canceled, return error
xlog.Warn("periodicalRunner has been canceled", "agent", a.Character.Name)
return
case <-timer.C: case <-timer.C:
a.periodicallyRun(timer) a.periodicallyRun(timer)
} }
} }
} }
func (a *Agent) loop(timer *time.Timer, job *types.Job) {
// Remember always to reset the timer - if we don't the agent will stop..
defer timer.Reset(a.options.periodicRuns)
// Consume the job and generate a response
// TODO: Give a short-term memory to the agent
// stop and drain the timer
if !timer.Stop() {
<-timer.C
}
xlog.Debug("Agent is consuming a job", "agent", a.Character.Name, "job", job)
a.consumeJob(job, UserRole)
}
func (a *Agent) Observer() Observer { func (a *Agent) Observer() Observer {
return a.observer return a.observer
} }

View File

@@ -36,8 +36,7 @@ func NewSSEObserver(agent string, manager sse.Manager) *SSEObserver {
} }
func (s *SSEObserver) NewObservable() *types.Observable { func (s *SSEObserver) NewObservable() *types.Observable {
id := atomic.AddInt32(&s.maxID, 1) id := atomic.AddInt32(&s.maxID, 1)
return &types.Observable{ return &types.Observable{
ID: id - 1, ID: id - 1,
Agent: s.agent, Agent: s.agent,

View File

@@ -54,7 +54,8 @@ type options struct {
newConversationsSubscribers []func(openai.ChatCompletionMessage) newConversationsSubscribers []func(openai.ChatCompletionMessage)
observer Observer observer Observer
parallelJobs int
} }
func (o *options) SeparatedMultimodalModel() bool { func (o *options) SeparatedMultimodalModel() bool {
@@ -63,6 +64,7 @@ func (o *options) SeparatedMultimodalModel() bool {
func defaultOptions() *options { func defaultOptions() *options {
return &options{ return &options{
parallelJobs: 1,
periodicRuns: 15 * time.Minute, periodicRuns: 15 * time.Minute,
LLMAPI: llmOptions{ LLMAPI: llmOptions{
APIURL: "http://localhost:8080", APIURL: "http://localhost:8080",
@@ -138,6 +140,13 @@ func EnableKnowledgeBaseWithResults(results int) Option {
} }
} }
func WithParallelJobs(jobs int) Option {
return func(o *options) error {
o.parallelJobs = jobs
return nil
}
}
func WithNewConversationSubscriber(sub func(openai.ChatCompletionMessage)) Option { func WithNewConversationSubscriber(sub func(openai.ChatCompletionMessage)) Option {
return func(o *options) error { return func(o *options) error {
o.newConversationsSubscribers = append(o.newConversationsSubscribers, sub) o.newConversationsSubscribers = append(o.newConversationsSubscribers, sub)

View File

@@ -25,6 +25,7 @@ var _ = Describe("Agent test", func() {
agent, err = New( agent, err = New(
WithLLMAPIURL(apiURL), WithLLMAPIURL(apiURL),
WithModel(testModel), WithModel(testModel),
WithTimeout("10m"),
WithRandomIdentity(), WithRandomIdentity(),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())

View File

@@ -61,6 +61,7 @@ type AgentConfig struct {
SystemPrompt string `json:"system_prompt" form:"system_prompt"` SystemPrompt string `json:"system_prompt" form:"system_prompt"`
LongTermMemory bool `json:"long_term_memory" form:"long_term_memory"` LongTermMemory bool `json:"long_term_memory" form:"long_term_memory"`
SummaryLongTermMemory bool `json:"summary_long_term_memory" form:"summary_long_term_memory"` SummaryLongTermMemory bool `json:"summary_long_term_memory" form:"summary_long_term_memory"`
ParallelJobs int `json:"parallel_jobs" form:"parallel_jobs"`
} }
type AgentConfigMeta struct { type AgentConfigMeta struct {
@@ -260,6 +261,16 @@ func NewAgentConfigMeta(
Step: 1, Step: 1,
Tags: config.Tags{Section: "AdvancedSettings"}, Tags: config.Tags{Section: "AdvancedSettings"},
}, },
{
Name: "parallel_jobs",
Label: "Parallel Jobs",
Type: "number",
DefaultValue: 5,
Min: 1,
Step: 1,
HelpText: "Number of concurrent tasks that can run in parallel",
Tags: config.Tags{Section: "AdvancedSettings"},
},
}, },
MCPServers: []config.Field{ MCPServers: []config.Field{
{ {

View File

@@ -166,56 +166,7 @@ func (a *AgentPool) CreateAgent(name string, agentConfig *AgentConfig) error {
} }
}(a.pool[name]) }(a.pool[name])
return a.startAgentWithConfig(name, agentConfig, nil) return a.startAgentWithConfig(name, agentConfig)
}
func (a *AgentPool) RecreateAgent(name string, agentConfig *AgentConfig) error {
a.Lock()
defer a.Unlock()
oldAgent := a.agents[name]
var o *types.Observable
obs := oldAgent.Observer()
if obs != nil {
o = obs.NewObservable()
o.Name = "Restarting Agent"
o.Icon = "sync"
o.Creation = &types.Creation{}
obs.Update(*o)
}
stateFile, characterFile := a.stateFiles(name)
os.Remove(stateFile)
os.Remove(characterFile)
oldAgent.Stop()
a.pool[name] = *agentConfig
delete(a.agents, name)
if err := a.save(); err != nil {
if obs != nil {
o.Completion = &types.Completion{Error: err.Error()}
obs.Update(*o)
}
return err
}
if err := a.startAgentWithConfig(name, agentConfig, obs); err != nil {
if obs != nil {
o.Completion = &types.Completion{Error: err.Error()}
obs.Update(*o)
}
return err
}
if obs != nil {
o.Completion = &types.Completion{}
obs.Update(*o)
}
return nil
} }
func createAgentAvatar(APIURL, APIKey, model, imageModel, avatarDir string, agent AgentConfig) error { func createAgentAvatar(APIURL, APIKey, model, imageModel, avatarDir string, agent AgentConfig) error {
@@ -317,13 +268,8 @@ func (a *AgentPool) GetStatusHistory(name string) *Status {
return a.agentStatus[name] return a.agentStatus[name]
} }
func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs Observer) error { func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error {
var manager sse.Manager manager := sse.NewManager(5)
if m, ok := a.managers[name]; ok {
manager = m
} else {
manager = sse.NewManager(5)
}
ctx := context.Background() ctx := context.Background()
model := a.defaultModel model := a.defaultModel
multimodalModel := a.defaultMultimodalModel multimodalModel := a.defaultMultimodalModel
@@ -385,10 +331,6 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O
// dynamicPrompts = append(dynamicPrompts, p.ToMap()) // dynamicPrompts = append(dynamicPrompts, p.ToMap())
// } // }
if obs == nil {
obs = NewSSEObserver(name, manager)
}
opts := []Option{ opts := []Option{
WithModel(model), WithModel(model),
WithLLMAPIURL(a.apiURL), WithLLMAPIURL(a.apiURL),
@@ -465,7 +407,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O
c.AgentResultCallback()(state) c.AgentResultCallback()(state)
} }
}), }),
WithObserver(obs), WithObserver(NewSSEObserver(name, manager)),
} }
if config.HUD { if config.HUD {
@@ -524,6 +466,10 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig, obs O
opts = append(opts, WithLoopDetectionSteps(config.LoopDetectionSteps)) opts = append(opts, WithLoopDetectionSteps(config.LoopDetectionSteps))
} }
if config.ParallelJobs > 0 {
opts = append(opts, WithParallelJobs(config.ParallelJobs))
}
xlog.Info("Starting agent", "name", name, "config", config) xlog.Info("Starting agent", "name", name, "config", config)
agent, err := New(opts...) agent, err := New(opts...)
@@ -568,7 +514,7 @@ func (a *AgentPool) StartAll() error {
if a.agents[name] != nil { // Agent already started if a.agents[name] != nil { // Agent already started
continue continue
} }
if err := a.startAgentWithConfig(name, &config, nil); err != nil { if err := a.startAgentWithConfig(name, &config); err != nil {
xlog.Error("Failed to start agent", "name", name, "error", err) xlog.Error("Failed to start agent", "name", name, "error", err)
} }
} }
@@ -606,7 +552,7 @@ func (a *AgentPool) Start(name string) error {
return nil return nil
} }
if config, ok := a.pool[name]; ok { if config, ok := a.pool[name]; ok {
return a.startAgentWithConfig(name, &config, nil) return a.startAgentWithConfig(name, &config)
} }
return fmt.Errorf("agent %s not found", name) return fmt.Errorf("agent %s not found", name)

View File

@@ -77,9 +77,8 @@ func (i *IRC) Start(a *agent.Agent) {
} }
i.conn.UseTLS = false i.conn.UseTLS = false
i.conn.AddCallback("001", func(e *irc.Event) { i.conn.AddCallback("001", func(e *irc.Event) {
xlog.Info("Connected to IRC server", "server", i.server, "arguments", e.Arguments) xlog.Info("Connected to IRC server", "server", i.server)
i.conn.Join(i.channel) i.conn.Join(i.channel)
i.nickname = e.Arguments[0]
xlog.Info("Joined channel", "channel", i.channel) xlog.Info("Joined channel", "channel", i.channel)
}) })
@@ -208,13 +207,6 @@ func (i *IRC) Start(a *agent.Agent) {
// Start the IRC client in a goroutine // Start the IRC client in a goroutine
go i.conn.Loop() go i.conn.Loop()
go func() {
select {
case <-a.Context().Done():
i.conn.Quit()
return
}
}()
} }
// IRCConfigMeta returns the metadata for IRC connector configuration fields // IRCConfigMeta returns the metadata for IRC connector configuration fields

View File

@@ -176,7 +176,17 @@ func (a *App) UpdateAgentConfig(pool *state.AgentPool) func(c *fiber.Ctx) error
return errorJSONMessage(c, err.Error()) return errorJSONMessage(c, err.Error())
} }
if err := pool.RecreateAgent(agentName, &newConfig); err != nil { // Remove the agent first
if err := pool.Remove(agentName); err != nil {
return errorJSONMessage(c, "Error removing agent: "+err.Error())
}
// Create agent with new config
if err := pool.CreateAgent(agentName, &newConfig); err != nil {
// Try to restore the old configuration if update fails
if restoreErr := pool.CreateAgent(agentName, oldConfig); restoreErr != nil {
return errorJSONMessage(c, fmt.Sprintf("Failed to update agent and restore failed: %v, %v", err, restoreErr))
}
return errorJSONMessage(c, "Error updating agent: "+err.Error()) return errorJSONMessage(c, "Error updating agent: "+err.Error())
} }

View File

@@ -99,17 +99,8 @@ function AgentStatus() {
creation: data.creation, creation: data.creation,
progress: data.progress, progress: data.progress,
completion: data.completion, completion: data.completion,
// children are always built client-side
}; };
// Events can be received out of order
if (data.creation)
updated.creation = data.creation;
if (data.completion)
updated.completion = data.completion;
if (data.parent_id && !prevMap[data.parent_id])
prevMap[data.parent_id] = {
id: data.parent_id,
name: "unknown",
};
const newMap = { ...prevMap, [data.id]: updated }; const newMap = { ...prevMap, [data.id]: updated };
setObservableTree(buildObservableTree(newMap)); setObservableTree(buildObservableTree(newMap));
return newMap; return newMap;