From 36abf837a9f831b55b189aa8a6eb53ddf9966592 Mon Sep 17 00:00:00 2001 From: mudler Date: Tue, 9 Apr 2024 20:05:08 +0200 Subject: [PATCH] Add rag commands --- llm/rag.go | 71 ++++++++++++++++++++++++ llm/store.go | 152 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 223 insertions(+) create mode 100644 llm/rag.go create mode 100644 llm/store.go diff --git a/llm/rag.go b/llm/rag.go new file mode 100644 index 0000000..10ca98e --- /dev/null +++ b/llm/rag.go @@ -0,0 +1,71 @@ +package llm + +import ( + "context" + "fmt" + + "github.com/sashabaranov/go-openai" +) + +func StoreStringEmbeddingInVectorDB(apiHost string, openaiClient *openai.Client, s string) error { + // Example usage + client := NewStoreClient(apiHost) + + resp, err := openaiClient.CreateEmbeddings(context.TODO(), + openai.EmbeddingRequestStrings{ + Input: []string{s}, + Model: openai.AdaEmbeddingV2, + }, + ) + if err != nil { + return fmt.Errorf("error getting keys: %v", err) + } + + if len(resp.Data) == 0 { + return fmt.Errorf("no response from OpenAI API") + } + + embedding := resp.Data[0].Embedding + + setReq := SetRequest{ + Keys: [][]float32{embedding}, + Values: []string{s}, + } + err = client.Set(setReq) + if err != nil { + return fmt.Errorf("error setting keys: %v", err) + } + + return nil +} + +func FindSimilarStrings(apiHost string, openaiClient *openai.Client, s string, similarEntries int) ([]string, error) { + client := NewStoreClient(apiHost) + + resp, err := openaiClient.CreateEmbeddings(context.TODO(), + openai.EmbeddingRequestStrings{ + Input: []string{s}, + Model: openai.AdaEmbeddingV2, + }, + ) + if err != nil { + return []string{}, fmt.Errorf("error getting keys: %v", err) + } + + if len(resp.Data) == 0 { + return []string{}, fmt.Errorf("no response from OpenAI API") + } + embedding := resp.Data[0].Embedding + + // Find example + findReq := FindRequest{ + TopK: similarEntries, // Number of similar entries you want to find + Key: embedding, // The key you're looking for similarities to + } + findResp, err := client.Find(findReq) + if err != nil { + return []string{}, fmt.Errorf("error finding keys: %v", err) + } + + return findResp.Values, nil +} diff --git a/llm/store.go b/llm/store.go new file mode 100644 index 0000000..ba3bebe --- /dev/null +++ b/llm/store.go @@ -0,0 +1,152 @@ +package llm + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" +) + +// Define a struct to hold your store API client +type StoreClient struct { + BaseURL string + Client *http.Client +} + +// Define request and response struct formats based on the API documentation +type SetRequest struct { + Keys [][]float32 `json:"keys"` + Values []string `json:"values"` +} + +type GetRequest struct { + Keys [][]float32 `json:"keys"` +} + +type GetResponse struct { + Keys [][]float32 `json:"keys"` + Values []string `json:"values"` +} + +type DeleteRequest struct { + Keys [][]float32 `json:"keys"` +} + +type FindRequest struct { + TopK int `json:"topk"` + Key []float32 `json:"key"` +} + +type FindResponse struct { + Keys [][]float32 `json:"keys"` + Values []string `json:"values"` + Similarities []float32 `json:"similarities"` +} + +// Constructor for StoreClient +func NewStoreClient(baseUrl string) *StoreClient { + return &StoreClient{ + BaseURL: baseUrl, + Client: &http.Client{}, + } +} + +// Implement Set method +func (c *StoreClient) Set(req SetRequest) error { + return c.doRequest("stores/set", req) +} + +// Implement Get method +func (c *StoreClient) Get(req GetRequest) (*GetResponse, error) { + body, err := c.doRequestWithResponse("stores/get", req) + if err != nil { + return nil, err + } + + var resp GetResponse + err = json.Unmarshal(body, &resp) + if err != nil { + return nil, err + } + + return &resp, nil +} + +// Implement Delete method +func (c *StoreClient) Delete(req DeleteRequest) error { + return c.doRequest("stores/delete", req) +} + +// Implement Find method +func (c *StoreClient) Find(req FindRequest) (*FindResponse, error) { + body, err := c.doRequestWithResponse("stores/find", req) + if err != nil { + return nil, err + } + + var resp FindResponse + err = json.Unmarshal(body, &resp) + if err != nil { + return nil, err + } + + return &resp, nil +} + +// Helper function to perform a request without expecting a response body +func (c *StoreClient) doRequest(path string, data interface{}) error { + jsonData, err := json.Marshal(data) + if err != nil { + return err + } + + req, err := http.NewRequest("POST", c.BaseURL+"/"+path, bytes.NewBuffer(jsonData)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.Client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("API request to %s failed with status code %d", path, resp.StatusCode) + } + + return nil +} + +// Helper function to perform a request and parse the response body +func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]byte, error) { + jsonData, err := json.Marshal(data) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", c.BaseURL+"/"+path, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.Client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request to %s failed with status code %d", path, resp.StatusCode) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return body, nil +}