This commit is contained in:
mudler
2024-04-16 17:44:04 +02:00
parent 4ed61daf8e
commit 35d9ba44f5
7 changed files with 66 additions and 26 deletions

View File

@@ -67,10 +67,14 @@ func (a *Agent) decision(
ToolChoice: toolchoice, ToolChoice: toolchoice,
} }
resp, err := a.client.CreateChatCompletion(ctx, decision) resp, err := a.client.CreateChatCompletion(ctx, decision)
if err != nil || len(resp.Choices) != 1 { if err != nil {
return nil, err return nil, err
} }
if len(resp.Choices) != 1 {
return nil, fmt.Errorf("no choices")
}
msg := resp.Choices[0].Message msg := resp.Choices[0].Message
if len(msg.ToolCalls) != 1 { if len(msg.ToolCalls) != 1 {
return &decisionResult{message: msg.Content}, nil return &decisionResult{message: msg.Content}, nil

View File

@@ -51,7 +51,7 @@ func New(opts ...Option) (*Agent, error) {
return nil, fmt.Errorf("failed to set options: %v", err) return nil, fmt.Errorf("failed to set options: %v", err)
} }
client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL) client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout)
c := context.Background() c := context.Background()
if options.context != nil { if options.context != nil {
@@ -103,6 +103,10 @@ func (a *Agent) Context() context.Context {
return a.context.Context return a.context.Context
} }
func (a *Agent) ActionContext() context.Context {
return a.actionContext.Context
}
func (a *Agent) ConversationChannel() chan openai.ChatCompletionMessage { func (a *Agent) ConversationChannel() chan openai.ChatCompletionMessage {
return a.newConversations return a.newConversations
} }
@@ -111,7 +115,13 @@ func (a *Agent) ConversationChannel() chan openai.ChatCompletionMessage {
// It discards any other computation. // It discards any other computation.
func (a *Agent) Ask(opts ...JobOption) *JobResult { func (a *Agent) Ask(opts ...JobOption) *JobResult {
a.StopAction() a.StopAction()
j := NewJob(append(opts, WithReasoningCallback(a.options.reasoningCallback), WithResultCallback(a.options.resultCallback))...) j := NewJob(
append(
opts,
WithReasoningCallback(a.options.reasoningCallback),
WithResultCallback(a.options.resultCallback),
)...,
)
// slog.Info("Job created", text) // slog.Info("Job created", text)
a.jobQueue <- j a.jobQueue <- j
return j.Result.WaitResult() return j.Result.WaitResult()
@@ -164,7 +174,7 @@ func (a *Agent) Paused() bool {
func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (result string, err error) { func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (result string, err error) {
for _, action := range a.systemInternalActions() { for _, action := range a.systemInternalActions() {
if action.Definition().Name == chosenAction.Definition().Name { if action.Definition().Name == chosenAction.Definition().Name {
if result, err = action.Run(a.context, decisionResult.actionParams); err != nil { if result, err = action.Run(a.actionContext, decisionResult.actionParams); err != nil {
return "", fmt.Errorf("error running action: %w", err) return "", fmt.Errorf("error running action: %w", err)
} }
} }

View File

@@ -29,6 +29,7 @@ type options struct {
statefile string statefile string
context context.Context context context.Context
permanentGoal string permanentGoal string
timeout string
periodicRuns time.Duration periodicRuns time.Duration
kbResults int kbResults int
ragdb RAGDB ragdb RAGDB
@@ -98,6 +99,13 @@ func LogLevel(level slog.Level) Option {
} }
} }
func WithTimeout(timeout string) Option {
return func(o *options) error {
o.timeout = timeout
return nil
}
}
func EnableKnowledgeBaseWithResults(results int) Option { func EnableKnowledgeBaseWithResults(results int) Option {
return func(o *options) error { return func(o *options) error {
o.enableKB = true o.enableKB = true

View File

@@ -155,6 +155,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error
), ),
WithStateFile(stateFile), WithStateFile(stateFile),
WithCharacterFile(characterFile), WithCharacterFile(characterFile),
WithTimeout(timeout),
WithRAGDB(a.ragDB), WithRAGDB(a.ragDB),
WithAgentReasoningCallback(func(state ActionCurrentState) bool { WithAgentReasoningCallback(func(state ActionCurrentState) bool {
slog.Info("Reasoning", state.Reasoning) slog.Info("Reasoning", state.Reasoning)

View File

@@ -61,6 +61,7 @@ func (g *GithubIssues) Start(a *agent.Agent) {
slog.Info("Looking into github issues...") slog.Info("Looking into github issues...")
g.issuesService() g.issuesService()
case <-a.Context().Done(): case <-a.Context().Done():
slog.Info("GithubIssues connector is now stopping")
return return
} }
} }
@@ -152,7 +153,6 @@ func (g *GithubIssues) issuesService() {
continue continue
} }
go func() {
res := g.agent.Ask( res := g.agent.Ask(
agent.WithConversationHistory(messages), agent.WithConversationHistory(messages),
) )
@@ -171,6 +171,5 @@ func (g *GithubIssues) issuesService() {
if err != nil { if err != nil {
slog.Error("Error creating comment", "error", err) slog.Error("Error creating comment", "error", err)
} }
}()
} }
} }

View File

@@ -21,6 +21,7 @@ var apiURL = os.Getenv("API_URL")
var apiKey = os.Getenv("API_KEY") var apiKey = os.Getenv("API_KEY")
var vectorStore = os.Getenv("VECTOR_STORE") var vectorStore = os.Getenv("VECTOR_STORE")
var kbdisableIndexing = os.Getenv("KBDISABLEINDEX") var kbdisableIndexing = os.Getenv("KBDISABLEINDEX")
var timeout = os.Getenv("TIMEOUT")
const defaultChunkSize = 4098 const defaultChunkSize = 4098
@@ -31,6 +32,9 @@ func init() {
if apiURL == "" { if apiURL == "" {
apiURL = "http://192.168.68.113:8080" apiURL = "http://192.168.68.113:8080"
} }
if timeout == "" {
timeout = "5m"
}
} }
//go:embed views/* //go:embed views/*
@@ -50,7 +54,7 @@ func main() {
os.MkdirAll(stateDir, 0755) os.MkdirAll(stateDir, 0755)
var dbStore RAGDB var dbStore RAGDB
lai := llm.NewClient(apiKey, apiURL+"/v1") lai := llm.NewClient(apiKey, apiURL+"/v1", timeout)
switch vectorStore { switch vectorStore {
case "localai": case "localai":

View File

@@ -1,8 +1,13 @@
package llm package llm
import "github.com/sashabaranov/go-openai" import (
"net/http"
"time"
func NewClient(APIKey, URL string) *openai.Client { "github.com/sashabaranov/go-openai"
)
func NewClient(APIKey, URL, timeout string) *openai.Client {
// Set up OpenAI client // Set up OpenAI client
if APIKey == "" { if APIKey == "" {
//log.Fatal("OPENAI_API_KEY environment variable not set") //log.Fatal("OPENAI_API_KEY environment variable not set")
@@ -10,5 +15,14 @@ func NewClient(APIKey, URL string) *openai.Client {
} }
config := openai.DefaultConfig(APIKey) config := openai.DefaultConfig(APIKey)
config.BaseURL = URL config.BaseURL = URL
dur, err := time.ParseDuration(timeout)
if err != nil {
dur = 150 * time.Second
}
config.HTTPClient = &http.Client{
Timeout: dur,
}
return openai.NewClientWithConfig(config) return openai.NewClientWithConfig(config)
} }