diff --git a/agent/agent.go b/agent/agent.go index a2d4321..106f5e3 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -32,6 +32,7 @@ type Agent struct { nextAction Action currentConversation Messages selfEvaluationInProgress bool + pause bool newConversations chan openai.ChatCompletionMessage } @@ -137,6 +138,24 @@ func (a *Agent) Stop() { a.context.Cancel() } +func (a *Agent) Pause() { + a.Lock() + defer a.Unlock() + a.pause = true +} + +func (a *Agent) Resume() { + a.Lock() + defer a.Unlock() + a.pause = false +} + +func (a *Agent) Paused() bool { + a.Lock() + defer a.Unlock() + return a.pause +} + func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (result string, err error) { for _, action := range a.systemInternalActions() { if action.Definition().Name == chosenAction.Definition().Name { @@ -174,6 +193,18 @@ func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) ( } func (a *Agent) consumeJob(job *Job, role string) { + a.Lock() + paused := a.pause + a.Unlock() + + if paused { + if a.options.debugMode { + fmt.Println("Agent is paused, skipping job") + } + job.Result.Finish(fmt.Errorf("agent is paused")) + return + } + // We are self evaluating if we consume the job as a system role selfEvaluation := role == SystemRole diff --git a/example/webui/app.go b/example/webui/app.go index 23e036b..63ece10 100644 --- a/example/webui/app.go +++ b/example/webui/app.go @@ -137,6 +137,27 @@ func (a *App) Delete(pool *AgentPool) func(c *fiber.Ctx) error { } } +func (a *App) Pause(pool *AgentPool) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + fmt.Println("Pausing agent", c.Params("name")) + agent := pool.GetAgent(c.Params("name")) + if agent != nil { + agent.Pause() + } + return c.Redirect("/agents") + } +} + +func (a *App) Start(pool *AgentPool) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + agent := pool.GetAgent(c.Params("name")) + if agent != nil { + agent.Resume() + } + return c.Redirect("/agents") + } +} + func (a *App) Create(pool *AgentPool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { config := AgentConfig{} diff --git a/example/webui/routes.go b/example/webui/routes.go index 22cb22e..a580d81 100644 --- a/example/webui/routes.go +++ b/example/webui/routes.go @@ -24,8 +24,13 @@ func RegisterRoutes(webapp *fiber.App, pool *AgentPool, db *InMemoryDatabase, ap }) webapp.Get("/agents", func(c *fiber.Ctx) error { + statuses := map[string]bool{} + for _, a := range pool.List() { + statuses[a] = !pool.GetAgent(a).Paused() + } return c.Render("views/agents", fiber.Map{ "Agents": pool.List(), + "Status": statuses, }) }) @@ -72,6 +77,8 @@ func RegisterRoutes(webapp *fiber.App, pool *AgentPool, db *InMemoryDatabase, ap webapp.Post("/chat/:name", app.Chat(pool)) webapp.Post("/create", app.Create(pool)) webapp.Get("/delete/:name", app.Delete(pool)) + webapp.Put("/pause/:name", app.Pause(pool)) + webapp.Put("/start/:name", app.Start(pool)) webapp.Post("/knowledgebase", app.KnowledgeBase(db)) webapp.Post("/knowledgebase/upload", app.KnowledgeBaseFile(db)) diff --git a/example/webui/views/agents.html b/example/webui/views/agents.html index f2cdcb2..35d2b1f 100644 --- a/example/webui/views/agents.html +++ b/example/webui/views/agents.html @@ -40,6 +40,12 @@