feat(agent): shared state, allow to track conversations globally (#148)

* feat(agent): shared state, allow to track conversations globally

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Cleanup

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* track conversations initiated by the bot

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-05-11 22:23:01 +02:00
committed by GitHub
parent 2b07dd79ec
commit c23e655f44
63 changed files with 290 additions and 316 deletions

View File

@@ -11,12 +11,14 @@ import (
"time"
"github.com/google/uuid"
"github.com/mudler/LocalAGI/core/conversations"
coreTypes "github.com/mudler/LocalAGI/core/types"
internalTypes "github.com/mudler/LocalAGI/core/types"
"github.com/mudler/LocalAGI/pkg/llm"
"github.com/mudler/LocalAGI/pkg/xlog"
"github.com/mudler/LocalAGI/services"
"github.com/mudler/LocalAGI/services/connectors"
"github.com/mudler/LocalAGI/webui/types"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
@@ -33,6 +35,7 @@ type (
htmx *htmx.HTMX
config *Config
*fiber.App
sharedState *internalTypes.AgentSharedState
}
)
@@ -47,9 +50,10 @@ func NewApp(opts ...Option) *App {
})
a := &App{
htmx: htmx.New(),
config: config,
App: webapp,
htmx: htmx.New(),
config: config,
App: webapp,
sharedState: internalTypes.NewAgentSharedState(5 * time.Minute),
}
a.registerRoutes(config.Pool, webapp)
@@ -443,7 +447,7 @@ func (a *App) GetActionDefinition(pool *state.AgentPool) func(c *fiber.Ctx) erro
}
}
func (a *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error {
func (app *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
payload := struct {
Config map[string]string `json:"config"`
@@ -467,7 +471,7 @@ func (a *App) ExecuteAction(pool *state.AgentPool) func(c *fiber.Ctx) error {
ctx, cancel := context.WithTimeout(c.Context(), 200*time.Second)
defer cancel()
res, err := a.Run(ctx, payload.Params)
res, err := a.Run(ctx, app.sharedState, payload.Params)
if err != nil {
xlog.Error("Error running action", "error", err)
return errorJSONMessage(c, err.Error())
@@ -484,7 +488,7 @@ func (a *App) ListActions() func(c *fiber.Ctx) error {
}
}
func (a *App) Responses(pool *state.AgentPool, tracker *connectors.ConversationTracker[string]) func(c *fiber.Ctx) error {
func (a *App) Responses(pool *state.AgentPool, tracker *conversations.ConversationTracker[string]) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
var request types.RequestBody
if err := c.BodyParser(&request); err != nil {

View File

@@ -13,8 +13,8 @@ import (
fiber "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/keyauth"
"github.com/mudler/LocalAGI/core/conversations"
"github.com/mudler/LocalAGI/core/sse"
"github.com/mudler/LocalAGI/services/connectors"
"github.com/mudler/LocalAGI/core/state"
"github.com/mudler/LocalAGI/core/types"
@@ -138,7 +138,7 @@ func (app *App) registerRoutes(pool *state.AgentPool, webapp *fiber.App) {
webapp.Post("/api/chat/:name", app.Chat(pool))
conversationTracker := connectors.NewConversationTracker[string](app.config.ConversationStoreDuration)
conversationTracker := conversations.NewConversationTracker[string](app.config.ConversationStoreDuration)
webapp.Post("/v1/responses", app.Responses(pool, conversationTracker))
@@ -268,7 +268,7 @@ func (app *App) registerRoutes(pool *state.AgentPool, webapp *fiber.App) {
}
return c.JSON(fiber.Map{
"Name": name,
"Name": name,
"History": agent.Observer().History(),
})
})