Fix genimage action
This commit is contained in:
13
services/actions/actions_suite_test.go
Normal file
13
services/actions/actions_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package actions_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestActions(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Agent actions test suite")
|
||||||
|
}
|
||||||
@@ -31,13 +31,16 @@ func (a *GenImageAction) Run(ctx context.Context, params action.ActionParams) (s
|
|||||||
}{}
|
}{}
|
||||||
err := params.Unmarshal(&result)
|
err := params.Unmarshal(&result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("error: %v", err)
|
|
||||||
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if result.Prompt == "" {
|
||||||
|
return "", fmt.Errorf("prompt is required")
|
||||||
|
}
|
||||||
|
|
||||||
req := openai.ImageRequest{
|
req := openai.ImageRequest{
|
||||||
Prompt: result.Prompt,
|
Prompt: result.Prompt,
|
||||||
|
Model: a.imageModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch result.Size {
|
switch result.Size {
|
||||||
|
|||||||
70
services/actions/genimage_test.go
Normal file
70
services/actions/genimage_test.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package actions_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
. "github.com/mudler/LocalAgent/core/action"
|
||||||
|
|
||||||
|
. "github.com/mudler/LocalAgent/services/actions"
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("GenImageAction", func() {
|
||||||
|
var (
|
||||||
|
ctx context.Context
|
||||||
|
action *GenImageAction
|
||||||
|
params ActionParams
|
||||||
|
config map[string]string
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
ctx = context.Background()
|
||||||
|
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||||
|
apiURL := os.Getenv("OPENAI_API_URL")
|
||||||
|
testModel := os.Getenv("OPENAI_MODEL")
|
||||||
|
if apiURL == "" {
|
||||||
|
Skip("OPENAI_API_URL must be set")
|
||||||
|
}
|
||||||
|
config = map[string]string{
|
||||||
|
"apiKey": apiKey,
|
||||||
|
"apiURL": apiURL,
|
||||||
|
"model": testModel,
|
||||||
|
}
|
||||||
|
action = NewGenImage(config)
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("Run", func() {
|
||||||
|
It("should generate an image with valid prompt and size", func() {
|
||||||
|
params = ActionParams{
|
||||||
|
"prompt": "test prompt",
|
||||||
|
"size": "256x256",
|
||||||
|
}
|
||||||
|
|
||||||
|
url, err := action.Run(ctx, params)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(url).ToNot(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return an error if the prompt is not provided", func() {
|
||||||
|
params = ActionParams{
|
||||||
|
"size": "256x256",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := action.Run(ctx, params)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Describe("Definition", func() {
|
||||||
|
It("should return the correct action definition", func() {
|
||||||
|
definition := action.Definition()
|
||||||
|
Expect(definition.Name.String()).To(Equal("generate_image"))
|
||||||
|
Expect(definition.Description).To(Equal("Generate image with."))
|
||||||
|
Expect(definition.Properties).To(HaveKey("prompt"))
|
||||||
|
Expect(definition.Properties).To(HaveKey("size"))
|
||||||
|
Expect(definition.Required).To(ContainElement("prompt"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -3,11 +3,11 @@ package actions
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/google/go-github/v61/github"
|
"github.com/google/go-github/v61/github"
|
||||||
"github.com/mudler/LocalAgent/core/action"
|
"github.com/mudler/LocalAgent/core/action"
|
||||||
|
"github.com/mudler/LocalAgent/pkg/xlog"
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,7 +53,7 @@ func (g *GithubIssuesLabeler) Run(ctx context.Context, params action.ActionParam
|
|||||||
labels, _, err := g.client.Issues.AddLabelsToIssue(g.context, result.Owner, result.Repository, result.IssueNumber, []string{result.Label})
|
labels, _, err := g.client.Issues.AddLabelsToIssue(g.context, result.Owner, result.Repository, result.IssueNumber, []string{result.Label})
|
||||||
//labelsNames := []string{}
|
//labelsNames := []string{}
|
||||||
for _, l := range labels {
|
for _, l := range labels {
|
||||||
slog.Info("Label added:", l.Name)
|
xlog.Info("Label added", "label", l.Name)
|
||||||
//labelsNames = append(labelsNames, l.GetName())
|
//labelsNames = append(labelsNames, l.GetName())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ package actions
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
|
||||||
|
|
||||||
"github.com/google/go-github/v61/github"
|
"github.com/google/go-github/v61/github"
|
||||||
"github.com/mudler/LocalAgent/core/action"
|
"github.com/mudler/LocalAgent/core/action"
|
||||||
|
"github.com/mudler/LocalAgent/pkg/xlog"
|
||||||
"github.com/sashabaranov/go-openai/jsonschema"
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ func (g *GithubIssueSearch) Run(ctx context.Context, params action.ActionParams)
|
|||||||
return resultString, err
|
return resultString, err
|
||||||
}
|
}
|
||||||
for _, i := range issues.Issues {
|
for _, i := range issues.Issues {
|
||||||
slog.Info("Issue found:", i.GetTitle())
|
xlog.Info("Issue found", "title", i.GetTitle())
|
||||||
resultString += fmt.Sprintf("Issue found: %s\n", i.GetTitle())
|
resultString += fmt.Sprintf("Issue found: %s\n", i.GetTitle())
|
||||||
resultString += fmt.Sprintf("URL: %s\n", i.GetHTMLURL())
|
resultString += fmt.Sprintf("URL: %s\n", i.GetHTMLURL())
|
||||||
// resultString += fmt.Sprintf("Body: %s\n", i.GetBody())
|
// resultString += fmt.Sprintf("Body: %s\n", i.GetBody())
|
||||||
|
|||||||
Reference in New Issue
Block a user