From a83f4512b6dd4e360163b3d340ffe2c367cd6cfa Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 19 Mar 2025 23:10:14 +0100 Subject: [PATCH] feat: allow to set LocalRAG API URL ad key (#61) Signed-off-by: Ettore Di Giacinto --- core/state/config.go | 9 +++- core/state/pool.go | 42 ++++++++++----- pkg/localrag/client.go | 77 +++++++++++++++++++++------- webui/views/partials/agent-form.html | 20 ++++++++ webui/views/settings.html | 4 ++ 5 files changed, 119 insertions(+), 33 deletions(-) diff --git a/core/state/config.go b/core/state/config.go index a93c752..e1692c3 100644 --- a/core/state/config.go +++ b/core/state/config.go @@ -35,8 +35,13 @@ type AgentConfig struct { Description string `json:"description" form:"description"` // This is what needs to be part of ActionsConfig - Model string `json:"model" form:"model"` - MultimodalModel string `json:"multimodal_model" form:"multimodal_model"` + Model string `json:"model" form:"model"` + MultimodalModel string `json:"multimodal_model" form:"multimodal_model"` + APIURL string `json:"api_url" form:"api_url"` + APIKey string `json:"api_key" form:"api_key"` + LocalRAGURL string `json:"local_rag_url" form:"local_rag_url"` + LocalRAGAPIKey string `json:"local_rag_api_key" form:"local_rag_api_key"` + Name string `json:"name" form:"name"` HUD bool `json:"hud" form:"hud"` StandaloneJob bool `json:"standalone_job" form:"standalone_job"` diff --git a/core/state/pool.go b/core/state/pool.go index b79d0b3..4efc4f2 100644 --- a/core/state/pool.go +++ b/core/state/pool.go @@ -21,18 +21,18 @@ import ( type AgentPool struct { sync.Mutex - file string - pooldir string - pool AgentPoolData - agents map[string]*Agent - managers map[string]sse.Manager - agentStatus map[string]*Status - apiURL, defaultModel, defaultMultimodalModel, localRAGAPI, apiKey string - availableActions func(*AgentConfig) func(ctx context.Context, pool *AgentPool) []Action - connectors func(*AgentConfig) []Connector - promptBlocks func(*AgentConfig) []PromptBlock - timeout string - conversationLogs string + file string + pooldir string + pool AgentPoolData + agents map[string]*Agent + managers map[string]sse.Manager + agentStatus map[string]*Status + apiURL, defaultModel, defaultMultimodalModel, localRAGAPI, localRAGKey, apiKey string + availableActions func(*AgentConfig) func(ctx context.Context, pool *AgentPool) []Action + connectors func(*AgentConfig) []Connector + promptBlocks func(*AgentConfig) []PromptBlock + timeout string + conversationLogs string } type Status struct { @@ -182,6 +182,22 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error config.PeriodicRuns = "10m" } + if config.APIURL != "" { + a.apiURL = config.APIURL + } + + if config.APIKey != "" { + a.apiKey = config.APIKey + } + + if config.LocalRAGURL != "" { + a.localRAGAPI = config.LocalRAGURL + } + + if config.LocalRAGAPIKey != "" { + a.localRAGKey = config.LocalRAGAPIKey + } + connectors := a.connectors(config) promptBlocks := a.promptBlocks(config) @@ -231,7 +247,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error WithCharacterFile(characterFile), WithLLMAPIKey(a.apiKey), WithTimeout(a.timeout), - WithRAGDB(localrag.NewWrappedClient(a.localRAGAPI, name)), + WithRAGDB(localrag.NewWrappedClient(a.localRAGAPI, a.localRAGKey, name)), WithAgentReasoningCallback(func(state ActionCurrentState) bool { xlog.Info( "Agent is thinking", diff --git a/pkg/localrag/client.go b/pkg/localrag/client.go index 4034a8f..af69aee 100644 --- a/pkg/localrag/client.go +++ b/pkg/localrag/client.go @@ -26,9 +26,9 @@ type WrappedClient struct { collection string } -func NewWrappedClient(baseURL, collection string) *WrappedClient { +func NewWrappedClient(baseURL, apiKey, collection string) *WrappedClient { wc := &WrappedClient{ - Client: NewClient(baseURL), + Client: NewClient(baseURL, apiKey), collection: collection, } @@ -104,15 +104,25 @@ type Result struct { // Client is a client for the RAG API type Client struct { BaseURL string + APIKey string } // NewClient creates a new RAG API client -func NewClient(baseURL string) *Client { +func NewClient(baseURL, apiKey string) *Client { return &Client{ BaseURL: baseURL, + APIKey: apiKey, } } +// Add a helper method to set the Authorization header +func (c *Client) addAuthHeader(req *http.Request) { + if c.APIKey == "" { + return + } + req.Header.Set("Authorization", "Bearer "+c.APIKey) +} + // CreateCollection creates a new collection func (c *Client) CreateCollection(name string) error { url := fmt.Sprintf("%s/api/collections", c.BaseURL) @@ -126,7 +136,15 @@ func (c *Client) CreateCollection(name string) error { return err } - resp, err := http.Post(url, "application/json", bytes.NewBuffer(payload)) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + c.addAuthHeader(req) + + client := &http.Client{} + resp, err := client.Do(req) if err != nil { return err } @@ -143,7 +161,14 @@ func (c *Client) CreateCollection(name string) error { func (c *Client) ListCollections() ([]string, error) { url := fmt.Sprintf("%s/api/collections", c.BaseURL) - resp, err := http.Get(url) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + c.addAuthHeader(req) + + client := &http.Client{} + resp, err := client.Do(req) if err != nil { return nil, err } @@ -162,18 +187,25 @@ func (c *Client) ListCollections() ([]string, error) { return collections, nil } -// ListCollections lists all collections +// ListEntries lists all entries in a collection func (c *Client) ListEntries(collection string) ([]string, error) { url := fmt.Sprintf("%s/api/collections/%s/entries", c.BaseURL, collection) - resp, err := http.Get(url) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + c.addAuthHeader(req) + + client := &http.Client{} + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return nil, errors.New("failed to list collections") + return nil, errors.New("failed to list entries") } var entries []string @@ -185,39 +217,37 @@ func (c *Client) ListEntries(collection string) ([]string, error) { return entries, nil } -// DeleteEntry deletes an Entry in a collection and return the entries left +// DeleteEntry deletes an entry in a collection func (c *Client) DeleteEntry(collection, entry string) ([]string, error) { url := fmt.Sprintf("%s/api/collections/%s/entry/delete", c.BaseURL, collection) type request struct { Entry string `json:"entry"` } - client := &http.Client{} + payload, err := json.Marshal(request{Entry: entry}) if err != nil { return nil, err } - // Create request - req, err := http.NewRequest("DELETE", url, bytes.NewBuffer(payload)) + req, err := http.NewRequest(http.MethodDelete, url, bytes.NewBuffer(payload)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") + c.addAuthHeader(req) - // Fetch Request + client := &http.Client{} resp, err := client.Do(req) if err != nil { return nil, err } - defer resp.Body.Close() if resp.StatusCode != http.StatusOK { bodyResult := new(bytes.Buffer) bodyResult.ReadFrom(resp.Body) - return nil, errors.New("failed to delete collection: " + bodyResult.String()) + return nil, errors.New("failed to delete entry: " + bodyResult.String()) } var results []string @@ -243,7 +273,15 @@ func (c *Client) Search(collection, query string, maxResults int) ([]Result, err return nil, err } - resp, err := http.Post(url, "application/json", bytes.NewBuffer(payload)) + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(payload)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + c.addAuthHeader(req) + + client := &http.Client{} + resp, err := client.Do(req) if err != nil { return nil, err } @@ -262,12 +300,15 @@ func (c *Client) Search(collection, query string, maxResults int) ([]Result, err return results, nil } +// Reset resets a collection func (c *Client) Reset(collection string) error { url := fmt.Sprintf("%s/api/collections/%s/reset", c.BaseURL, collection) + req, err := http.NewRequest(http.MethodPost, url, nil) if err != nil { return err } + c.addAuthHeader(req) client := &http.Client{} resp, err := client.Do(req) @@ -279,7 +320,6 @@ func (c *Client) Reset(collection string) error { if resp.StatusCode != http.StatusOK { b := new(bytes.Buffer) b.ReadFrom(resp.Body) - return errors.New("failed to reset collection: " + b.String()) } @@ -319,6 +359,7 @@ func (c *Client) Store(collection, filePath string) error { return err } req.Header.Set("Content-Type", writer.FormDataContentType()) + c.addAuthHeader(req) client := &http.Client{} resp, err := client.Do(req) diff --git a/webui/views/partials/agent-form.html b/webui/views/partials/agent-form.html index ec7ff45..c36ee32 100644 --- a/webui/views/partials/agent-form.html +++ b/webui/views/partials/agent-form.html @@ -174,6 +174,26 @@ +
+ + +
+ +
+ + +
+ +
+ + +
+ +
+ + +
+