diff --git a/main.py b/main.py index 48314af..62ca4ba 100644 --- a/main.py +++ b/main.py @@ -366,8 +366,6 @@ if __name__ == "__main__": logger.info("Creating LocalAGI instance") localagi = LocalAGI( agent_actions=agent_actions, - embeddings_model=EMBEDDINGS_MODEL, - embeddings_api_base=EMBEDDINGS_API_BASE, llm_model=LLM_MODEL, tts_model=VOICE_MODEL, tts_api_base=TTS_API_BASE, diff --git a/src/localagi/localagi.py b/src/localagi/localagi.py index e73fd21..8426ea4 100644 --- a/src/localagi/localagi.py +++ b/src/localagi/localagi.py @@ -8,7 +8,6 @@ DEFAULT_API_BASE = "http://api:8080" VOICE_MODEL = "en-us-kathleen-low.onnx" STABLEDIFFUSION_MODEL = "stablediffusion" FUNCTIONS_MODEL = "functions" -EMBEDDINGS_MODEL = "all-MiniLM-L6-v2" LLM_MODEL = "gpt-4" # LocalAGI class @@ -23,13 +22,13 @@ class LocalAGI: api_base=DEFAULT_API_BASE, tts_api_base="", stablediffusion_api_base="", - embeddings_api_base="", tts_model=VOICE_MODEL, stablediffusion_model=STABLEDIFFUSION_MODEL, functions_model=FUNCTIONS_MODEL, - embeddings_model=EMBEDDINGS_MODEL, llm_model=LLM_MODEL, tts_player="aplay", + action_callback=None, + reasoning_callback=None, ): self.api_base = api_base self.agent_actions = agent_actions @@ -37,6 +36,8 @@ class LocalAGI: self.force_action = force_action self.processed_messages=0 self.tts_player = tts_player + self.action_callback = action_callback + self.reasoning_callback = reasoning_callback self.agent_actions[plan_action] = { "function": self.generate_plan, "plannable": False, @@ -63,11 +64,9 @@ class LocalAGI: } self.tts_api_base = tts_api_base if tts_api_base else self.api_base self.stablediffusion_api_base = stablediffusion_api_base if stablediffusion_api_base else self.api_base - self.embeddings_api_base = embeddings_api_base if embeddings_api_base else self.api_base self.tts_model = tts_model self.stablediffusion_model = stablediffusion_model self.functions_model = functions_model - self.embeddings_model = embeddings_model self.llm_model = llm_model self.reply_action = reply_action # Function to create images with LocalAI @@ -216,6 +215,8 @@ class LocalAGI: function_parameters = response.choices[0].message["function_call"].arguments logger.info("==> function parameters: {function_parameters}",function_parameters=function_parameters) function_to_call = self.agent_actions[function_name]["function"] + if self.action_callback: + self.action_callback(function_name, function_parameters) function_result = function_to_call(function_parameters, agent_actions=self.agent_actions, localagi=self) logger.info("==> function result: {function_result}", function_result=function_result) @@ -484,6 +485,9 @@ class LocalAGI: logger.error(e) action = {"action": self.reply_action} + if self.reasoning_callback: + self.reasoning_callback(action["action"], action["detailed_reasoning"]) + if action["action"] != self.reply_action: logger.info("==> LocalAGI wants to call '{action}'", action=action["action"]) #logger.info("==> Observation '{reasoning}'", reasoning=action["detailed_reasoning"])