feat(api): implement stateful responses api (#122)
* feat(api): implement stateful responses api Signed-off-by: mudler <mudler@localai.io> * fix(tests): align client to API changes Signed-off-by: mudler <mudler@localai.io> --------- Signed-off-by: mudler <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
86cb9f1282
commit
f3c06b1bfb
2
main.go
2
main.go
@@ -21,6 +21,7 @@ var localRAG = os.Getenv("LOCALAGENT_LOCALRAG_URL")
|
|||||||
var withLogs = os.Getenv("LOCALAGENT_ENABLE_CONVERSATIONS_LOGGING") == "true"
|
var withLogs = os.Getenv("LOCALAGENT_ENABLE_CONVERSATIONS_LOGGING") == "true"
|
||||||
var apiKeysEnv = os.Getenv("LOCALAGENT_API_KEYS")
|
var apiKeysEnv = os.Getenv("LOCALAGENT_API_KEYS")
|
||||||
var imageModel = os.Getenv("LOCALAGENT_IMAGE_MODEL")
|
var imageModel = os.Getenv("LOCALAGENT_IMAGE_MODEL")
|
||||||
|
var conversationDuration = os.Getenv("LOCALAGENT_CONVERSATION_DURATION")
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if testModel == "" {
|
if testModel == "" {
|
||||||
@@ -73,6 +74,7 @@ func main() {
|
|||||||
// Create the application
|
// Create the application
|
||||||
app := webui.NewApp(
|
app := webui.NewApp(
|
||||||
webui.WithPool(pool),
|
webui.WithPool(pool),
|
||||||
|
webui.WithConversationStoreduration(conversationDuration),
|
||||||
webui.WithApiKeys(apiKeys...),
|
webui.WithApiKeys(apiKeys...),
|
||||||
webui.WithLLMAPIUrl(apiURL),
|
webui.WithLLMAPIUrl(apiURL),
|
||||||
webui.WithLLMAPIKey(apiKey),
|
webui.WithLLMAPIKey(apiKey),
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ func (c *Client) GetAgentConfig(name string) (*AgentConfig, error) {
|
|||||||
|
|
||||||
// CreateAgent creates a new agent with the given configuration
|
// CreateAgent creates a new agent with the given configuration
|
||||||
func (c *Client) CreateAgent(config *AgentConfig) error {
|
func (c *Client) CreateAgent(config *AgentConfig) error {
|
||||||
resp, err := c.doRequest(http.MethodPost, "/create", config)
|
resp, err := c.doRequest(http.MethodPost, "/api/agent/create", config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -96,7 +96,7 @@ func (c *Client) UpdateAgentConfig(name string, config *AgentConfig) error {
|
|||||||
|
|
||||||
// DeleteAgent removes an agent
|
// DeleteAgent removes an agent
|
||||||
func (c *Client) DeleteAgent(name string) error {
|
func (c *Client) DeleteAgent(name string) error {
|
||||||
path := fmt.Sprintf("/delete/%s", name)
|
path := fmt.Sprintf("/api/agent/%s", name)
|
||||||
resp, err := c.doRequest(http.MethodDelete, path, nil)
|
resp, err := c.doRequest(http.MethodDelete, path, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -116,7 +116,7 @@ func (c *Client) DeleteAgent(name string) error {
|
|||||||
|
|
||||||
// PauseAgent pauses an agent
|
// PauseAgent pauses an agent
|
||||||
func (c *Client) PauseAgent(name string) error {
|
func (c *Client) PauseAgent(name string) error {
|
||||||
path := fmt.Sprintf("/pause/%s", name)
|
path := fmt.Sprintf("/api/agent/pause/%s", name)
|
||||||
resp, err := c.doRequest(http.MethodPut, path, nil)
|
resp, err := c.doRequest(http.MethodPut, path, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -136,7 +136,7 @@ func (c *Client) PauseAgent(name string) error {
|
|||||||
|
|
||||||
// StartAgent starts a paused agent
|
// StartAgent starts a paused agent
|
||||||
func (c *Client) StartAgent(name string) error {
|
func (c *Client) StartAgent(name string) error {
|
||||||
path := fmt.Sprintf("/start/%s", name)
|
path := fmt.Sprintf("/api/agent/start/%s", name)
|
||||||
resp, err := c.doRequest(http.MethodPut, path, nil)
|
resp, err := c.doRequest(http.MethodPut, path, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -73,3 +73,12 @@ func (c *ConversationTracker[K]) AddMessage(key K, message openai.ChatCompletion
|
|||||||
c.currentconversation[key] = append(c.currentconversation[key], message)
|
c.currentconversation[key] = append(c.currentconversation[key], message)
|
||||||
c.lastMessageTime[key] = time.Now()
|
c.lastMessageTime[key] = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConversationTracker[K]) SetConversation(key K, messages []openai.ChatCompletionMessage) {
|
||||||
|
// Lock the conversation mutex to update the conversation history
|
||||||
|
c.convMutex.Lock()
|
||||||
|
defer c.convMutex.Unlock()
|
||||||
|
|
||||||
|
c.currentconversation[key] = messages
|
||||||
|
c.lastMessageTime[key] = time.Now()
|
||||||
|
}
|
||||||
|
|||||||
25
webui/app.go
25
webui/app.go
@@ -10,11 +10,14 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
coreTypes "github.com/mudler/LocalAgent/core/types"
|
coreTypes "github.com/mudler/LocalAgent/core/types"
|
||||||
"github.com/mudler/LocalAgent/pkg/llm"
|
"github.com/mudler/LocalAgent/pkg/llm"
|
||||||
"github.com/mudler/LocalAgent/pkg/xlog"
|
"github.com/mudler/LocalAgent/pkg/xlog"
|
||||||
"github.com/mudler/LocalAgent/services"
|
"github.com/mudler/LocalAgent/services"
|
||||||
|
"github.com/mudler/LocalAgent/services/connectors"
|
||||||
"github.com/mudler/LocalAgent/webui/types"
|
"github.com/mudler/LocalAgent/webui/types"
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
|
|
||||||
"github.com/mudler/LocalAgent/core/sse"
|
"github.com/mudler/LocalAgent/core/sse"
|
||||||
@@ -405,7 +408,7 @@ func (a *App) ListActions() func(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) Responses(pool *state.AgentPool) func(c *fiber.Ctx) error {
|
func (a *App) Responses(pool *state.AgentPool, tracker *connectors.ConversationTracker[string]) func(c *fiber.Ctx) error {
|
||||||
return func(c *fiber.Ctx) error {
|
return func(c *fiber.Ctx) error {
|
||||||
var request types.RequestBody
|
var request types.RequestBody
|
||||||
if err := c.BodyParser(&request); err != nil {
|
if err := c.BodyParser(&request); err != nil {
|
||||||
@@ -414,9 +417,15 @@ func (a *App) Responses(pool *state.AgentPool) func(c *fiber.Ctx) error {
|
|||||||
|
|
||||||
request.SetInputByType()
|
request.SetInputByType()
|
||||||
|
|
||||||
agentName := request.Model
|
var previousResponseID string
|
||||||
|
conv := []openai.ChatCompletionMessage{}
|
||||||
|
if request.PreviousResponseID != nil {
|
||||||
|
previousResponseID = *request.PreviousResponseID
|
||||||
|
conv = tracker.GetConversation(previousResponseID)
|
||||||
|
}
|
||||||
|
|
||||||
messages := request.ToChatCompletionMessages()
|
agentName := request.Model
|
||||||
|
messages := append(conv, request.ToChatCompletionMessages()...)
|
||||||
|
|
||||||
a := pool.GetAgent(agentName)
|
a := pool.GetAgent(agentName)
|
||||||
if a == nil {
|
if a == nil {
|
||||||
@@ -435,7 +444,17 @@ func (a *App) Responses(pool *state.AgentPool) func(c *fiber.Ctx) error {
|
|||||||
xlog.Info("we got a response from the agent", "agent", agentName, "response", res.Response)
|
xlog.Info("we got a response from the agent", "agent", agentName, "response", res.Response)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conv = append(conv, openai.ChatCompletionMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: res.Response,
|
||||||
|
})
|
||||||
|
|
||||||
|
id := uuid.New().String()
|
||||||
|
|
||||||
|
tracker.SetConversation(id, conv)
|
||||||
|
|
||||||
response := types.ResponseBody{
|
response := types.ResponseBody{
|
||||||
|
ID: id,
|
||||||
Object: "response",
|
Object: "response",
|
||||||
// "created_at": 1741476542,
|
// "created_at": 1741476542,
|
||||||
CreatedAt: time.Now().Unix(),
|
CreatedAt: time.Now().Unix(),
|
||||||
|
|||||||
@@ -1,15 +1,20 @@
|
|||||||
package webui
|
package webui
|
||||||
|
|
||||||
import "github.com/mudler/LocalAgent/core/state"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAgent/core/state"
|
||||||
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
DefaultChunkSize int
|
DefaultChunkSize int
|
||||||
Pool *state.AgentPool
|
Pool *state.AgentPool
|
||||||
ApiKeys []string
|
ApiKeys []string
|
||||||
LLMAPIURL string
|
LLMAPIURL string
|
||||||
LLMAPIKey string
|
LLMAPIKey string
|
||||||
LLMModel string
|
LLMModel string
|
||||||
StateDir string
|
StateDir string
|
||||||
|
ConversationStoreDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type Option func(*Config)
|
type Option func(*Config)
|
||||||
@@ -20,6 +25,16 @@ func WithDefaultChunkSize(size int) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithConversationStoreduration(duration string) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
d, err := time.ParseDuration(duration)
|
||||||
|
if err != nil {
|
||||||
|
d = 1 * time.Hour
|
||||||
|
}
|
||||||
|
c.ConversationStoreDuration = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func WithStateDir(dir string) Option {
|
func WithStateDir(dir string) Option {
|
||||||
return func(c *Config) {
|
return func(c *Config) {
|
||||||
c.StateDir = dir
|
c.StateDir = dir
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
"github.com/gofiber/fiber/v2/middleware/filesystem"
|
"github.com/gofiber/fiber/v2/middleware/filesystem"
|
||||||
"github.com/gofiber/fiber/v2/middleware/keyauth"
|
"github.com/gofiber/fiber/v2/middleware/keyauth"
|
||||||
"github.com/mudler/LocalAgent/core/sse"
|
"github.com/mudler/LocalAgent/core/sse"
|
||||||
|
"github.com/mudler/LocalAgent/services/connectors"
|
||||||
|
|
||||||
"github.com/mudler/LocalAgent/core/state"
|
"github.com/mudler/LocalAgent/core/state"
|
||||||
"github.com/mudler/LocalAgent/core/types"
|
"github.com/mudler/LocalAgent/core/types"
|
||||||
"github.com/mudler/LocalAgent/pkg/xlog"
|
"github.com/mudler/LocalAgent/pkg/xlog"
|
||||||
@@ -75,12 +77,17 @@ func (app *App) registerRoutes(pool *state.AgentPool, webapp *fiber.App) {
|
|||||||
webapp.Post("/api/chat/:name", app.Chat(pool))
|
webapp.Post("/api/chat/:name", app.Chat(pool))
|
||||||
webapp.Post("/api/notify/:name", app.Notify(pool))
|
webapp.Post("/api/notify/:name", app.Notify(pool))
|
||||||
|
|
||||||
webapp.Post("/v1/responses", app.Responses(pool))
|
conversationTracker := connectors.NewConversationTracker[string](app.config.ConversationStoreDuration)
|
||||||
|
|
||||||
|
webapp.Post("/v1/responses", app.Responses(pool, conversationTracker))
|
||||||
|
|
||||||
// New API endpoints for getting and updating agent configuration
|
// New API endpoints for getting and updating agent configuration
|
||||||
webapp.Get("/api/agent/:name/config", app.GetAgentConfig(pool))
|
webapp.Get("/api/agent/:name/config", app.GetAgentConfig(pool))
|
||||||
webapp.Put("/api/agent/:name/config", app.UpdateAgentConfig(pool))
|
webapp.Put("/api/agent/:name/config", app.UpdateAgentConfig(pool))
|
||||||
|
|
||||||
|
// Metadata endpoint for agent configuration fields
|
||||||
|
webapp.Get("/api/agent/config/metadata", app.GetAgentConfigMeta())
|
||||||
|
|
||||||
// Add endpoint for getting agent config metadata
|
// Add endpoint for getting agent config metadata
|
||||||
webapp.Get("/api/meta/agent/config", app.GetAgentConfigMeta())
|
webapp.Get("/api/meta/agent/config", app.GetAgentConfigMeta())
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user