diff --git a/agent/actions.go b/agent/actions.go index 018a517..bd23fcb 100644 --- a/agent/actions.go +++ b/agent/actions.go @@ -67,10 +67,14 @@ func (a *Agent) decision( ToolChoice: toolchoice, } resp, err := a.client.CreateChatCompletion(ctx, decision) - if err != nil || len(resp.Choices) != 1 { + if err != nil { return nil, err } + if len(resp.Choices) != 1 { + return nil, fmt.Errorf("no choices") + } + msg := resp.Choices[0].Message if len(msg.ToolCalls) != 1 { return &decisionResult{message: msg.Content}, nil diff --git a/agent/agent.go b/agent/agent.go index e9f792f..b90e600 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -51,7 +51,7 @@ func New(opts ...Option) (*Agent, error) { return nil, fmt.Errorf("failed to set options: %v", err) } - client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL) + client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout) c := context.Background() if options.context != nil { @@ -103,6 +103,10 @@ func (a *Agent) Context() context.Context { return a.context.Context } +func (a *Agent) ActionContext() context.Context { + return a.actionContext.Context +} + func (a *Agent) ConversationChannel() chan openai.ChatCompletionMessage { return a.newConversations } @@ -111,7 +115,13 @@ func (a *Agent) ConversationChannel() chan openai.ChatCompletionMessage { // It discards any other computation. func (a *Agent) Ask(opts ...JobOption) *JobResult { a.StopAction() - j := NewJob(append(opts, WithReasoningCallback(a.options.reasoningCallback), WithResultCallback(a.options.resultCallback))...) + j := NewJob( + append( + opts, + WithReasoningCallback(a.options.reasoningCallback), + WithResultCallback(a.options.resultCallback), + )..., + ) // slog.Info("Job created", text) a.jobQueue <- j return j.Result.WaitResult() @@ -164,7 +174,7 @@ func (a *Agent) Paused() bool { func (a *Agent) runAction(chosenAction Action, decisionResult *decisionResult) (result string, err error) { for _, action := range a.systemInternalActions() { if action.Definition().Name == chosenAction.Definition().Name { - if result, err = action.Run(a.context, decisionResult.actionParams); err != nil { + if result, err = action.Run(a.actionContext, decisionResult.actionParams); err != nil { return "", fmt.Errorf("error running action: %w", err) } } diff --git a/agent/options.go b/agent/options.go index c455813..6bf240e 100644 --- a/agent/options.go +++ b/agent/options.go @@ -29,6 +29,7 @@ type options struct { statefile string context context.Context permanentGoal string + timeout string periodicRuns time.Duration kbResults int ragdb RAGDB @@ -98,6 +99,13 @@ func LogLevel(level slog.Level) Option { } } +func WithTimeout(timeout string) Option { + return func(o *options) error { + o.timeout = timeout + return nil + } +} + func EnableKnowledgeBaseWithResults(results int) Option { return func(o *options) error { o.enableKB = true diff --git a/example/webui/agentpool.go b/example/webui/agentpool.go index d3c4339..8ee7edb 100644 --- a/example/webui/agentpool.go +++ b/example/webui/agentpool.go @@ -155,6 +155,7 @@ func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error ), WithStateFile(stateFile), WithCharacterFile(characterFile), + WithTimeout(timeout), WithRAGDB(a.ragDB), WithAgentReasoningCallback(func(state ActionCurrentState) bool { slog.Info("Reasoning", state.Reasoning) diff --git a/example/webui/connector/githubissue.go b/example/webui/connector/githubissue.go index d7e17a2..e91a786 100644 --- a/example/webui/connector/githubissue.go +++ b/example/webui/connector/githubissue.go @@ -61,6 +61,7 @@ func (g *GithubIssues) Start(a *agent.Agent) { slog.Info("Looking into github issues...") g.issuesService() case <-a.Context().Done(): + slog.Info("GithubIssues connector is now stopping") return } } @@ -152,25 +153,23 @@ func (g *GithubIssues) issuesService() { continue } - go func() { - res := g.agent.Ask( - agent.WithConversationHistory(messages), - ) - if res.Error != nil { - slog.Error("Error asking", "error", res.Error) - return - } + res := g.agent.Ask( + agent.WithConversationHistory(messages), + ) + if res.Error != nil { + slog.Error("Error asking", "error", res.Error) + return + } - _, _, err := g.client.Issues.CreateComment( - g.agent.Context(), - g.owner, g.repository, - issue.GetNumber(), &github.IssueComment{ - Body: github.String(res.Response), - }, - ) - if err != nil { - slog.Error("Error creating comment", "error", err) - } - }() + _, _, err := g.client.Issues.CreateComment( + g.agent.Context(), + g.owner, g.repository, + issue.GetNumber(), &github.IssueComment{ + Body: github.String(res.Response), + }, + ) + if err != nil { + slog.Error("Error creating comment", "error", err) + } } } diff --git a/example/webui/main.go b/example/webui/main.go index 367927a..795efd0 100644 --- a/example/webui/main.go +++ b/example/webui/main.go @@ -21,6 +21,7 @@ var apiURL = os.Getenv("API_URL") var apiKey = os.Getenv("API_KEY") var vectorStore = os.Getenv("VECTOR_STORE") var kbdisableIndexing = os.Getenv("KBDISABLEINDEX") +var timeout = os.Getenv("TIMEOUT") const defaultChunkSize = 4098 @@ -31,6 +32,9 @@ func init() { if apiURL == "" { apiURL = "http://192.168.68.113:8080" } + if timeout == "" { + timeout = "5m" + } } //go:embed views/* @@ -50,7 +54,7 @@ func main() { os.MkdirAll(stateDir, 0755) var dbStore RAGDB - lai := llm.NewClient(apiKey, apiURL+"/v1") + lai := llm.NewClient(apiKey, apiURL+"/v1", timeout) switch vectorStore { case "localai": diff --git a/llm/client.go b/llm/client.go index 10cba75..dc27afe 100644 --- a/llm/client.go +++ b/llm/client.go @@ -1,8 +1,13 @@ package llm -import "github.com/sashabaranov/go-openai" +import ( + "net/http" + "time" -func NewClient(APIKey, URL string) *openai.Client { + "github.com/sashabaranov/go-openai" +) + +func NewClient(APIKey, URL, timeout string) *openai.Client { // Set up OpenAI client if APIKey == "" { //log.Fatal("OPENAI_API_KEY environment variable not set") @@ -10,5 +15,14 @@ func NewClient(APIKey, URL string) *openai.Client { } config := openai.DefaultConfig(APIKey) config.BaseURL = URL + + dur, err := time.ParseDuration(timeout) + if err != nil { + dur = 150 * time.Second + } + + config.HTTPClient = &http.Client{ + Timeout: dur, + } return openai.NewClientWithConfig(config) }