refactoring

This commit is contained in:
Ettore Di Giacinto
2025-02-25 23:17:28 +01:00
parent 296734ba3b
commit 0139b79835
13 changed files with 177 additions and 139 deletions

224
core/sse/sse.go Normal file
View File

@@ -0,0 +1,224 @@
package sse
import (
"bufio"
"fmt"
"strings"
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"
)
type (
// Listener defines the interface for the receiving end.
Listener interface {
ID() string
Chan() chan Envelope
}
// Envelope defines the interface for content that can be broadcast to clients.
Envelope interface {
String() string // Represent the envelope contents as a string for transmission.
}
// Manager defines the interface for managing clients and broadcasting messages.
Manager interface {
Send(message Envelope)
Handle(ctx *fiber.Ctx, cl Listener)
Clients() []string
}
History interface {
Add(message Envelope) // Add adds a message to the history.
Send(c Listener) // Send sends the history to a client.
}
)
type Client struct {
id string
ch chan Envelope
}
func NewClient(id string) Listener {
return &Client{
id: id,
ch: make(chan Envelope, 50),
}
}
func (c *Client) ID() string { return c.id }
func (c *Client) Chan() chan Envelope { return c.ch }
// Message represents a simple message implementation.
type Message struct {
Event string
Time time.Time
Data string
}
// NewMessage returns a new message instance.
func NewMessage(data string) *Message {
return &Message{
Data: data,
Time: time.Now(),
}
}
// String returns the message as a string.
func (m *Message) String() string {
sb := strings.Builder{}
if m.Event != "" {
sb.WriteString(fmt.Sprintf("event: %s\n", m.Event))
}
sb.WriteString(fmt.Sprintf("data: %v\n\n", m.Data))
return sb.String()
}
// WithEvent sets the event name for the message.
func (m *Message) WithEvent(event string) Envelope {
m.Event = event
return m
}
// broadcastManager manages the clients and broadcasts messages to them.
type broadcastManager struct {
clients sync.Map
broadcast chan Envelope
workerPoolSize int
messageHistory *history
}
// NewManager initializes and returns a new Manager instance.
func NewManager(workerPoolSize int) Manager {
manager := &broadcastManager{
broadcast: make(chan Envelope),
workerPoolSize: workerPoolSize,
messageHistory: newHistory(10),
}
manager.startWorkers()
return manager
}
// Send broadcasts a message to all connected clients.
func (manager *broadcastManager) Send(message Envelope) {
manager.broadcast <- message
}
// Handle sets up a new client and handles the connection.
func (manager *broadcastManager) Handle(c *fiber.Ctx, cl Listener) {
manager.register(cl)
ctx := c.Context()
ctx.SetContentType("text/event-stream")
ctx.Response.Header.Set("Cache-Control", "no-cache")
ctx.Response.Header.Set("Connection", "keep-alive")
ctx.Response.Header.Set("Access-Control-Allow-Origin", "*")
ctx.Response.Header.Set("Access-Control-Allow-Headers", "Cache-Control")
ctx.Response.Header.Set("Access-Control-Allow-Credentials", "true")
// Send history to the newly connected client
manager.messageHistory.Send(cl)
ctx.SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
for {
select {
case msg, ok := <-cl.Chan():
if !ok {
// If the channel is closed, return from the function
return
}
_, err := fmt.Fprint(w, msg.String())
if err != nil {
// If an error occurs (e.g., client has disconnected), return from the function
return
}
w.Flush()
case <-ctx.Done():
manager.unregister(cl.ID())
close(cl.Chan())
return
}
}
}))
}
// Clients method to list connected client IDs
func (manager *broadcastManager) Clients() []string {
var clients []string
manager.clients.Range(func(key, value any) bool {
id, ok := key.(string)
if ok {
clients = append(clients, id)
}
return true
})
return clients
}
// startWorkers starts worker goroutines for message broadcasting.
func (manager *broadcastManager) startWorkers() {
for i := 0; i < manager.workerPoolSize; i++ {
go func() {
for message := range manager.broadcast {
manager.clients.Range(func(key, value any) bool {
client, ok := value.(Listener)
if !ok {
return true // Continue iteration
}
select {
case client.Chan() <- message:
manager.messageHistory.Add(message)
default:
// If the client's channel is full, drop the message
}
return true // Continue iteration
})
}
}()
}
}
// register adds a client to the manager.
func (manager *broadcastManager) register(client Listener) {
manager.clients.Store(client.ID(), client)
}
// unregister removes a client from the manager.
func (manager *broadcastManager) unregister(clientID string) {
manager.clients.Delete(clientID)
}
type history struct {
messages []Envelope
maxSize int // Maximum number of messages to retain
}
func newHistory(maxSize int) *history {
return &history{
messages: []Envelope{},
maxSize: maxSize,
}
}
func (h *history) Add(message Envelope) {
h.messages = append(h.messages, message)
// Ensure history does not exceed maxSize
if len(h.messages) > h.maxSize {
// Remove the oldest messages to fit the maxSize
h.messages = h.messages[len(h.messages)-h.maxSize:]
}
}
func (h *history) Send(c Listener) {
for _, msg := range h.messages {
c.Chan() <- msg
}
}

43
core/state/config.go Normal file
View File

@@ -0,0 +1,43 @@
package state
import (
"github.com/mudler/local-agent-framework/core/agent"
)
type ConnectorConfig struct {
Type string `json:"type"` // e.g. Slack
Config string `json:"config"`
}
type ActionsConfig struct {
Name string `json:"name"` // e.g. search
Config string `json:"config"`
}
type AgentConfig struct {
Connector []ConnectorConfig `json:"connectors" form:"connectors" `
Actions []ActionsConfig `json:"actions" form:"actions"`
// This is what needs to be part of ActionsConfig
Model string `json:"model" form:"model"`
Name string `json:"name" form:"name"`
HUD bool `json:"hud" form:"hud"`
StandaloneJob bool `json:"standalone_job" form:"standalone_job"`
RandomIdentity bool `json:"random_identity" form:"random_identity"`
InitiateConversations bool `json:"initiate_conversations" form:"initiate_conversations"`
IdentityGuidance string `json:"identity_guidance" form:"identity_guidance"`
PeriodicRuns string `json:"periodic_runs" form:"periodic_runs"`
PermanentGoal string `json:"permanent_goal" form:"permanent_goal"`
EnableKnowledgeBase bool `json:"enable_kb" form:"enable_kb"`
EnableReasoning bool `json:"enable_reasoning" form:"enable_reasoning"`
KnowledgeBaseResults int `json:"kb_results" form:"kb_results"`
CanStopItself bool `json:"can_stop_itself" form:"can_stop_itself"`
SystemPrompt string `json:"system_prompt" form:"system_prompt"`
LongTermMemory bool `json:"long_term_memory" form:"long_term_memory"`
SummaryLongTermMemory bool `json:"summary_long_term_memory" form:"summary_long_term_memory"`
}
type Connector interface {
AgentResultCallback() func(state agent.ActionState)
AgentReasoningCallback() func(state agent.ActionCurrentState) bool
Start(a *agent.Agent)
}

212
core/state/memory.go Normal file
View File

@@ -0,0 +1,212 @@
package state
import (
"encoding/json"
"fmt"
"io"
"github.com/mudler/local-agent-framework/pkg/xlog"
"net/http"
"os"
"strings"
"sync"
. "github.com/mudler/local-agent-framework/core/agent"
"jaytaylor.com/html2text"
sitemap "github.com/oxffaa/gopher-parse-sitemap"
)
type InMemoryDatabase struct {
RAGDB
sync.Mutex
Database []string
path string
}
func loadDB(path string) ([]string, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
poolData := []string{}
err = json.Unmarshal(data, &poolData)
return poolData, err
}
func NewInMemoryDB(poolfile string, store RAGDB) (*InMemoryDatabase, error) {
// if file exists, try to load an existing pool.
// if file does not exist, create a new pool.
if _, err := os.Stat(poolfile); err != nil {
// file does not exist, return a new pool
return &InMemoryDatabase{
Database: []string{},
path: poolfile,
RAGDB: store,
}, nil
}
poolData, err := loadDB(poolfile)
if err != nil {
return nil, err
}
db := &InMemoryDatabase{
RAGDB: store,
Database: poolData,
path: poolfile,
}
if err := db.populateRAGDB(); err != nil {
return nil, fmt.Errorf("error populating RAGDB: %w", err)
}
return db, nil
}
func (db *InMemoryDatabase) Data() []string {
db.Lock()
defer db.Unlock()
return db.Database
}
func (db *InMemoryDatabase) populateRAGDB() error {
for _, d := range db.Database {
if d == "" {
// skip empty chunks
continue
}
err := db.RAGDB.Store(d)
if err != nil {
return fmt.Errorf("error storing in the KB: %w", err)
}
}
return nil
}
func (db *InMemoryDatabase) Reset() error {
db.Lock()
db.Database = []string{}
db.Unlock()
if err := db.RAGDB.Reset(); err != nil {
return err
}
return db.SaveDB()
}
func (db *InMemoryDatabase) save() error {
data, err := json.Marshal(db.Database)
if err != nil {
return err
}
return os.WriteFile(db.path, data, 0644)
}
func (db *InMemoryDatabase) Store(entry string) error {
db.Lock()
defer db.Unlock()
db.Database = append(db.Database, entry)
if err := db.RAGDB.Store(entry); err != nil {
return err
}
return db.save()
}
func (db *InMemoryDatabase) SaveDB() error {
db.Lock()
defer db.Unlock()
return db.save()
}
func getWebPage(url string) (string, error) {
resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
return html2text.FromString(string(body), html2text.Options{PrettyTables: true})
}
func getWebSitemap(url string) (res []string, err error) {
err = sitemap.ParseFromSite(url, func(e sitemap.Entry) error {
xlog.Info("Sitemap page: " + e.GetLocation())
content, err := getWebPage(e.GetLocation())
if err == nil {
res = append(res, content)
}
return nil
})
return
}
func WebsiteToKB(website string, chunkSize int, db RAGDB) {
content, err := getWebSitemap(website)
if err != nil {
xlog.Info("Error walking sitemap for website", err)
}
xlog.Info("Found pages: ", len(content))
xlog.Info("ChunkSize: ", chunkSize)
StringsToKB(db, chunkSize, content...)
}
func StringsToKB(db RAGDB, chunkSize int, content ...string) {
for _, c := range content {
chunks := splitParagraphIntoChunks(c, chunkSize)
xlog.Info("chunks: ", len(chunks))
for _, chunk := range chunks {
xlog.Info("Chunk size: ", len(chunk))
db.Store(chunk)
}
}
}
// splitParagraphIntoChunks takes a paragraph and a maxChunkSize as input,
// and returns a slice of strings where each string is a chunk of the paragraph
// that is at most maxChunkSize long, ensuring that words are not split.
func splitParagraphIntoChunks(paragraph string, maxChunkSize int) []string {
if len(paragraph) <= maxChunkSize {
return []string{paragraph}
}
var chunks []string
var currentChunk strings.Builder
words := strings.Fields(paragraph) // Splits the paragraph into words.
for _, word := range words {
// If adding the next word would exceed maxChunkSize (considering a space if not the first word in a chunk),
// add the currentChunk to chunks, and reset currentChunk.
if currentChunk.Len() > 0 && currentChunk.Len()+len(word)+1 > maxChunkSize { // +1 for the space if not the first word
chunks = append(chunks, currentChunk.String())
currentChunk.Reset()
} else if currentChunk.Len() == 0 && len(word) > maxChunkSize { // Word itself exceeds maxChunkSize, split the word
chunks = append(chunks, word)
continue
}
// Add a space before the word if it's not the beginning of a new chunk.
if currentChunk.Len() > 0 {
currentChunk.WriteString(" ")
}
// Add the word to the current chunk.
currentChunk.WriteString(word)
}
// After the loop, add any remaining content in currentChunk to chunks.
if currentChunk.Len() > 0 {
chunks = append(chunks, currentChunk.String())
}
return chunks
}

437
core/state/pool.go Normal file
View File

@@ -0,0 +1,437 @@
package state
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"sync"
"time"
"github.com/mudler/local-agent-framework/core/agent"
. "github.com/mudler/local-agent-framework/core/agent"
"github.com/mudler/local-agent-framework/core/sse"
"github.com/mudler/local-agent-framework/pkg/utils"
"github.com/mudler/local-agent-framework/pkg/xlog"
)
type AgentPool struct {
sync.Mutex
file string
pooldir string
pool AgentPoolData
agents map[string]*Agent
managers map[string]sse.Manager
agentStatus map[string]*Status
agentMemory map[string]*InMemoryDatabase
apiURL, model string
ragDB RAGDB
availableActions func(*AgentConfig) func(ctx context.Context) []Action
connectors func(*AgentConfig) []Connector
timeout string
}
type Status struct {
ActionResults []ActionState
}
func (s *Status) addResult(result ActionState) {
// If we have more than 10 results, remove the oldest one
if len(s.ActionResults) > 10 {
s.ActionResults = s.ActionResults[1:]
}
s.ActionResults = append(s.ActionResults, result)
}
func (s *Status) Results() []ActionState {
return s.ActionResults
}
type AgentPoolData map[string]AgentConfig
func loadPoolFromFile(path string) (*AgentPoolData, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
poolData := &AgentPoolData{}
err = json.Unmarshal(data, poolData)
return poolData, err
}
func NewAgentPool(
model, apiURL, directory string,
RagDB RAGDB,
availableActions func(*AgentConfig) func(ctx context.Context) []agent.Action,
connectors func(*AgentConfig) []Connector,
timeout string,
) (*AgentPool, error) {
// if file exists, try to load an existing pool.
// if file does not exist, create a new pool.
poolfile := filepath.Join(directory, "pool.json")
if _, err := os.Stat(poolfile); err != nil {
// file does not exist, create a new pool
return &AgentPool{
file: poolfile,
pooldir: directory,
apiURL: apiURL,
model: model,
ragDB: RagDB,
agents: make(map[string]*Agent),
pool: make(map[string]AgentConfig),
agentStatus: make(map[string]*Status),
managers: make(map[string]sse.Manager),
agentMemory: make(map[string]*InMemoryDatabase),
connectors: connectors,
availableActions: availableActions,
timeout: timeout,
}, nil
}
poolData, err := loadPoolFromFile(poolfile)
if err != nil {
return nil, err
}
return &AgentPool{
file: poolfile,
apiURL: apiURL,
pooldir: directory,
ragDB: RagDB,
model: model,
agents: make(map[string]*Agent),
managers: make(map[string]sse.Manager),
agentStatus: map[string]*Status{},
agentMemory: map[string]*InMemoryDatabase{},
pool: *poolData,
connectors: connectors,
availableActions: availableActions,
timeout: timeout,
}, nil
}
// CreateAgent adds a new agent to the pool
// and starts it.
// It also saves the state to the file.
func (a *AgentPool) CreateAgent(name string, agentConfig *AgentConfig) error {
if _, ok := a.pool[name]; ok {
return fmt.Errorf("agent %s already exists", name)
}
a.pool[name] = *agentConfig
if err := a.Save(); err != nil {
return err
}
return a.startAgentWithConfig(name, agentConfig)
}
func (a *AgentPool) List() []string {
var agents []string
for agent := range a.pool {
agents = append(agents, agent)
}
// return a sorted list
sort.SliceStable(agents, func(i, j int) bool {
return agents[i] < agents[j]
})
return agents
}
func (a *AgentPool) GetStatusHistory(name string) *Status {
a.Lock()
defer a.Unlock()
return a.agentStatus[name]
}
func (a *AgentPool) startAgentWithConfig(name string, config *AgentConfig) error {
manager := sse.NewManager(5)
ctx := context.Background()
model := a.model
if config.Model != "" {
model = config.Model
}
if config.PeriodicRuns == "" {
config.PeriodicRuns = "10m"
}
connectors := a.connectors(config)
actions := a.availableActions(config)(ctx)
stateFile, characterFile, knowledgeBase := a.stateFiles(name)
agentDB, err := NewInMemoryDB(knowledgeBase, a.ragDB)
if err != nil {
return err
}
a.agentMemory[name] = agentDB
actionsLog := []string{}
for _, action := range actions {
actionsLog = append(actionsLog, action.Definition().Name.String())
}
connectorLog := []string{}
for _, connector := range connectors {
connectorLog = append(connectorLog, fmt.Sprintf("%+v", connector))
}
xlog.Info(
"Creating agent",
"name", name,
"model", model,
"api_url", a.apiURL,
"actions", actionsLog,
"connectors", connectorLog,
)
opts := []Option{
WithModel(model),
WithLLMAPIURL(a.apiURL),
WithContext(ctx),
WithPeriodicRuns(config.PeriodicRuns),
WithPermanentGoal(config.PermanentGoal),
WithCharacter(Character{
Name: name,
}),
WithActions(
actions...,
),
WithStateFile(stateFile),
WithCharacterFile(characterFile),
WithTimeout(a.timeout),
WithRAGDB(agentDB),
WithAgentReasoningCallback(func(state ActionCurrentState) bool {
xlog.Info(
"Agent is thinking",
"agent", name,
"reasoning", state.Reasoning,
"action", state.Action.Definition().Name,
"params", state.Params,
)
manager.Send(
sse.NewMessage(
fmt.Sprintf(`Thinking: %s`, utils.HTMLify(state.Reasoning)),
).WithEvent("status"),
)
for _, c := range connectors {
if !c.AgentReasoningCallback()(state) {
return false
}
}
return true
}),
WithSystemPrompt(config.SystemPrompt),
WithAgentResultCallback(func(state ActionState) {
a.Lock()
if _, ok := a.agentStatus[name]; !ok {
a.agentStatus[name] = &Status{}
}
a.agentStatus[name].addResult(state)
a.Unlock()
xlog.Info(
"Agent executed an action",
"agent", name,
"reasoning", state.Reasoning,
"action", state.ActionCurrentState.Action.Definition().Name,
"params", state.ActionCurrentState.Params,
"result", state.Result,
)
text := fmt.Sprintf(`Reasoning: %s
Action taken: %+v
Parameters: %+v
Result: %s`,
state.Reasoning,
state.ActionCurrentState.Action.Definition().Name,
state.ActionCurrentState.Params,
state.Result)
manager.Send(
sse.NewMessage(
utils.HTMLify(
text,
),
).WithEvent("status"),
)
for _, c := range connectors {
c.AgentResultCallback()(state)
}
}),
}
if config.HUD {
opts = append(opts, EnableHUD)
}
if config.StandaloneJob {
opts = append(opts, EnableStandaloneJob)
}
if config.LongTermMemory {
opts = append(opts, EnableLongTermMemory)
}
if config.SummaryLongTermMemory {
opts = append(opts, EnableSummaryMemory)
}
if config.CanStopItself {
opts = append(opts, CanStopItself)
}
if config.InitiateConversations {
opts = append(opts, EnableInitiateConversations)
}
if config.RandomIdentity {
if config.IdentityGuidance != "" {
opts = append(opts, WithRandomIdentity(config.IdentityGuidance))
} else {
opts = append(opts, WithRandomIdentity())
}
}
if config.EnableKnowledgeBase {
opts = append(opts, EnableKnowledgeBase)
}
if config.EnableReasoning {
opts = append(opts, EnableForceReasoning)
}
if config.KnowledgeBaseResults > 0 {
opts = append(opts, EnableKnowledgeBaseWithResults(config.KnowledgeBaseResults))
}
xlog.Info("Starting agent", "name", name, "config", config)
agent, err := New(opts...)
if err != nil {
return err
}
a.agents[name] = agent
a.managers[name] = manager
go func() {
if err := agent.Run(); err != nil {
xlog.Error("Agent stopped", "error", err.Error())
panic(err)
}
}()
for _, c := range connectors {
go c.Start(agent)
}
go func() {
for {
time.Sleep(1 * time.Second) // Send a message every seconds
manager.Send(sse.NewMessage(
utils.HTMLify(agent.State().String()),
).WithEvent("hud"))
}
}()
return nil
}
// Starts all the agents in the pool
func (a *AgentPool) StartAll() error {
for name, config := range a.pool {
if a.agents[name] != nil { // Agent already started
continue
}
if err := a.startAgentWithConfig(name, &config); err != nil {
return err
}
}
return nil
}
func (a *AgentPool) StopAll() {
for _, agent := range a.agents {
agent.Stop()
}
}
func (a *AgentPool) Stop(name string) {
if agent, ok := a.agents[name]; ok {
agent.Stop()
}
}
func (a *AgentPool) Start(name string) error {
if agent, ok := a.agents[name]; ok {
err := agent.Run()
if err != nil {
return fmt.Errorf("agent %s failed to start: %w", name, err)
}
xlog.Info("Agent started", "name", name)
return nil
}
if config, ok := a.pool[name]; ok {
return a.startAgentWithConfig(name, &config)
}
return fmt.Errorf("agent %s not found", name)
}
func (a *AgentPool) stateFiles(name string) (string, string, string) {
stateFile := filepath.Join(a.pooldir, fmt.Sprintf("%s.state.json", name))
characterFile := filepath.Join(a.pooldir, fmt.Sprintf("%s.character.json", name))
knowledgeBaseFile := filepath.Join(a.pooldir, fmt.Sprintf("%s.knowledgebase.json", name))
return stateFile, characterFile, knowledgeBaseFile
}
func (a *AgentPool) Remove(name string) error {
// Cleanup character and state
stateFile, characterFile, knowledgeBaseFile := a.stateFiles(name)
os.Remove(stateFile)
os.Remove(characterFile)
os.Remove(knowledgeBaseFile)
a.Stop(name)
delete(a.agents, name)
delete(a.pool, name)
if err := a.Save(); err != nil {
return err
}
return nil
}
func (a *AgentPool) Save() error {
data, err := json.MarshalIndent(a.pool, "", " ")
if err != nil {
return err
}
return os.WriteFile(a.file, data, 0644)
}
func (a *AgentPool) GetAgent(name string) *Agent {
return a.agents[name]
}
func (a *AgentPool) GetAgentMemory(name string) *InMemoryDatabase {
return a.agentMemory[name]
}
func (a *AgentPool) GetConfig(name string) *AgentConfig {
agent, exists := a.pool[name]
if !exists {
return nil
}
return &agent
}
func (a *AgentPool) GetManager(name string) sse.Manager {
return a.managers[name]
}