reordering
This commit is contained in:
28
pkg/llm/client.go
Normal file
28
pkg/llm/client.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func NewClient(APIKey, URL, timeout string) *openai.Client {
|
||||
// Set up OpenAI client
|
||||
if APIKey == "" {
|
||||
//log.Fatal("OPENAI_API_KEY environment variable not set")
|
||||
APIKey = "sk-xxx"
|
||||
}
|
||||
config := openai.DefaultConfig(APIKey)
|
||||
config.BaseURL = URL
|
||||
|
||||
dur, err := time.ParseDuration(timeout)
|
||||
if err != nil {
|
||||
dur = 150 * time.Second
|
||||
}
|
||||
|
||||
config.HTTPClient = &http.Client{
|
||||
Timeout: dur,
|
||||
}
|
||||
return openai.NewClientWithConfig(config)
|
||||
}
|
||||
47
pkg/llm/json.go
Normal file
47
pkg/llm/json.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// generateAnswer generates an answer for the given text using the OpenAI API
|
||||
func GenerateJSON(ctx context.Context, client *openai.Client, model, text string, i interface{}) error {
|
||||
req := openai.ChatCompletionRequest{
|
||||
ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject},
|
||||
Model: model,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
|
||||
Role: "user",
|
||||
Content: text,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate answer: %v", err)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return fmt.Errorf("no response from OpenAI API")
|
||||
}
|
||||
|
||||
err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GenerateJSONFromStruct(ctx context.Context, client *openai.Client, guidance, model string, i interface{}) error {
|
||||
// TODO: use functions?
|
||||
exampleJSON, err := json.Marshal(i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return GenerateJSON(ctx, client, model, "Generate a character as JSON data. "+guidance+". This is the JSON fields that should contain: "+string(exampleJSON), i)
|
||||
}
|
||||
113
pkg/llm/rag/chromem.go
Normal file
113
pkg/llm/rag/chromem.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/philippgille/chromem-go"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type ChromemDB struct {
|
||||
collectionName string
|
||||
collection *chromem.Collection
|
||||
index int
|
||||
client *openai.Client
|
||||
db *chromem.DB
|
||||
embeddingsModel string
|
||||
}
|
||||
|
||||
func NewChromemDB(collection, path string, openaiClient *openai.Client, embeddingsModel string) (*ChromemDB, error) {
|
||||
// db, err := chromem.NewPersistentDB(path, true)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
db := chromem.NewDB()
|
||||
|
||||
chromem := &ChromemDB{
|
||||
collectionName: collection,
|
||||
index: 1,
|
||||
db: db,
|
||||
client: openaiClient,
|
||||
embeddingsModel: embeddingsModel,
|
||||
}
|
||||
|
||||
c, err := db.GetOrCreateCollection(collection, nil, chromem.embedding())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chromem.collection = c
|
||||
|
||||
return chromem, nil
|
||||
}
|
||||
|
||||
func (c *ChromemDB) Count() int {
|
||||
return c.collection.Count()
|
||||
}
|
||||
|
||||
func (c *ChromemDB) Reset() error {
|
||||
if err := c.db.DeleteCollection(c.collectionName); err != nil {
|
||||
return err
|
||||
}
|
||||
collection, err := c.db.GetOrCreateCollection(c.collectionName, nil, c.embedding())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.collection = collection
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChromemDB) embedding() chromem.EmbeddingFunc {
|
||||
return chromem.EmbeddingFunc(
|
||||
func(ctx context.Context, text string) ([]float32, error) {
|
||||
resp, err := c.client.CreateEmbeddings(ctx,
|
||||
openai.EmbeddingRequestStrings{
|
||||
Input: []string{text},
|
||||
Model: openai.EmbeddingModel(c.embeddingsModel),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return []float32{}, fmt.Errorf("error getting keys: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
return []float32{}, fmt.Errorf("no response from OpenAI API")
|
||||
}
|
||||
|
||||
embedding := resp.Data[0].Embedding
|
||||
|
||||
return embedding, nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (c *ChromemDB) Store(s string) error {
|
||||
defer func() {
|
||||
c.index++
|
||||
}()
|
||||
if s == "" {
|
||||
return fmt.Errorf("empty string")
|
||||
}
|
||||
return c.collection.AddDocuments(context.Background(), []chromem.Document{
|
||||
{
|
||||
Content: s,
|
||||
ID: fmt.Sprint(c.index),
|
||||
},
|
||||
}, runtime.NumCPU())
|
||||
}
|
||||
|
||||
func (c *ChromemDB) Search(s string, similarEntries int) ([]string, error) {
|
||||
res, err := c.collection.Query(context.Background(), s, similarEntries, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var results []string
|
||||
for _, r := range res {
|
||||
results = append(results, r.Content)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
86
pkg/llm/rag/localai.go
Normal file
86
pkg/llm/rag/localai.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type LocalAIRAGDB struct {
|
||||
client *StoreClient
|
||||
openaiClient *openai.Client
|
||||
}
|
||||
|
||||
func NewLocalAIRAGDB(storeClient *StoreClient, openaiClient *openai.Client) *LocalAIRAGDB {
|
||||
return &LocalAIRAGDB{
|
||||
client: storeClient,
|
||||
openaiClient: openaiClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (db *LocalAIRAGDB) Reset() error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (db *LocalAIRAGDB) Count() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (db *LocalAIRAGDB) Store(s string) error {
|
||||
resp, err := db.openaiClient.CreateEmbeddings(context.TODO(),
|
||||
openai.EmbeddingRequestStrings{
|
||||
Input: []string{s},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting keys: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
return fmt.Errorf("no response from OpenAI API")
|
||||
}
|
||||
|
||||
embedding := resp.Data[0].Embedding
|
||||
|
||||
setReq := SetRequest{
|
||||
Keys: [][]float32{embedding},
|
||||
Values: []string{s},
|
||||
}
|
||||
err = db.client.Set(setReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error setting keys: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *LocalAIRAGDB) Search(s string, similarEntries int) ([]string, error) {
|
||||
resp, err := db.openaiClient.CreateEmbeddings(context.TODO(),
|
||||
openai.EmbeddingRequestStrings{
|
||||
Input: []string{s},
|
||||
Model: openai.AdaEmbeddingV2,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return []string{}, fmt.Errorf("error getting keys: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
return []string{}, fmt.Errorf("no response from OpenAI API")
|
||||
}
|
||||
embedding := resp.Data[0].Embedding
|
||||
|
||||
// Find example
|
||||
findReq := FindRequest{
|
||||
TopK: similarEntries, // Number of similar entries you want to find
|
||||
Key: embedding, // The key you're looking for similarities to
|
||||
}
|
||||
findResp, err := db.client.Find(findReq)
|
||||
if err != nil {
|
||||
return []string{}, fmt.Errorf("error finding keys: %v", err)
|
||||
}
|
||||
|
||||
return findResp.Values, nil
|
||||
}
|
||||
161
pkg/llm/rag/store.go
Normal file
161
pkg/llm/rag/store.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Define a struct to hold your store API client
|
||||
type StoreClient struct {
|
||||
BaseURL string
|
||||
APIToken string
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
// Define request and response struct formats based on the API documentation
|
||||
type SetRequest struct {
|
||||
Keys [][]float32 `json:"keys"`
|
||||
Values []string `json:"values"`
|
||||
}
|
||||
|
||||
type GetRequest struct {
|
||||
Keys [][]float32 `json:"keys"`
|
||||
}
|
||||
|
||||
type GetResponse struct {
|
||||
Keys [][]float32 `json:"keys"`
|
||||
Values []string `json:"values"`
|
||||
}
|
||||
|
||||
type DeleteRequest struct {
|
||||
Keys [][]float32 `json:"keys"`
|
||||
}
|
||||
|
||||
type FindRequest struct {
|
||||
TopK int `json:"topk"`
|
||||
Key []float32 `json:"key"`
|
||||
}
|
||||
|
||||
type FindResponse struct {
|
||||
Keys [][]float32 `json:"keys"`
|
||||
Values []string `json:"values"`
|
||||
Similarities []float32 `json:"similarities"`
|
||||
}
|
||||
|
||||
// Constructor for StoreClient
|
||||
func NewStoreClient(baseUrl, apiToken string) *StoreClient {
|
||||
return &StoreClient{
|
||||
BaseURL: baseUrl,
|
||||
APIToken: apiToken,
|
||||
Client: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Set method
|
||||
func (c *StoreClient) Set(req SetRequest) error {
|
||||
return c.doRequest("stores/set", req)
|
||||
}
|
||||
|
||||
// Implement Get method
|
||||
func (c *StoreClient) Get(req GetRequest) (*GetResponse, error) {
|
||||
body, err := c.doRequestWithResponse("stores/get", req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp GetResponse
|
||||
err = json.Unmarshal(body, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Implement Delete method
|
||||
func (c *StoreClient) Delete(req DeleteRequest) error {
|
||||
return c.doRequest("stores/delete", req)
|
||||
}
|
||||
|
||||
// Implement Find method
|
||||
func (c *StoreClient) Find(req FindRequest) (*FindResponse, error) {
|
||||
body, err := c.doRequestWithResponse("stores/find", req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp FindResponse
|
||||
err = json.Unmarshal(body, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Helper function to perform a request without expecting a response body
|
||||
func (c *StoreClient) doRequest(path string, data interface{}) error {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", c.BaseURL+"/"+path, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Set Bearer token
|
||||
if c.APIToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIToken)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.Client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("API request to %s failed with status code %d", path, resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper function to perform a request and parse the response body
|
||||
func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]byte, error) {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", c.BaseURL+"/"+path, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
// Set Bearer token
|
||||
if c.APIToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIToken)
|
||||
}
|
||||
resp, err := c.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request to %s failed with status code %d", path, resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
71
pkg/xlog/xlog.go
Normal file
71
pkg/xlog/xlog.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package xlog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var logger *slog.Logger
|
||||
|
||||
func init() {
|
||||
var level = slog.LevelDebug
|
||||
|
||||
switch os.Getenv("LOG_LEVEL") {
|
||||
case "info":
|
||||
level = slog.LevelInfo
|
||||
case "warn":
|
||||
level = slog.LevelWarn
|
||||
case "error":
|
||||
level = slog.LevelError
|
||||
case "debug":
|
||||
level = slog.LevelDebug
|
||||
}
|
||||
|
||||
var opts = &slog.HandlerOptions{
|
||||
Level: level,
|
||||
}
|
||||
|
||||
var handler slog.Handler
|
||||
|
||||
if os.Getenv("LOG_FORMAT") == "json" {
|
||||
handler = slog.NewJSONHandler(os.Stdout, opts)
|
||||
} else {
|
||||
handler = slog.NewTextHandler(os.Stdout, opts)
|
||||
}
|
||||
logger = slog.New(handler)
|
||||
}
|
||||
|
||||
func _log(level slog.Level, msg string, args ...any) {
|
||||
_, f, l, _ := runtime.Caller(2)
|
||||
group := slog.Group(
|
||||
"source",
|
||||
slog.Attr{
|
||||
Key: "file",
|
||||
Value: slog.AnyValue(f),
|
||||
},
|
||||
slog.Attr{
|
||||
Key: "L",
|
||||
Value: slog.AnyValue(l),
|
||||
},
|
||||
)
|
||||
args = append(args, group)
|
||||
logger.Log(context.Background(), level, msg, args...)
|
||||
}
|
||||
|
||||
func Info(msg string, args ...any) {
|
||||
_log(slog.LevelInfo, msg, args...)
|
||||
}
|
||||
|
||||
func Debug(msg string, args ...any) {
|
||||
_log(slog.LevelDebug, msg, args...)
|
||||
}
|
||||
|
||||
func Error(msg string, args ...any) {
|
||||
_log(slog.LevelError, msg, args...)
|
||||
}
|
||||
|
||||
func Warn(msg string, args ...any) {
|
||||
_log(slog.LevelWarn, msg, args...)
|
||||
}
|
||||
Reference in New Issue
Block a user