Small fixups

This commit is contained in:
mudler
2023-08-05 16:56:00 +02:00
parent 881edbad07
commit 74449e90e5
2 changed files with 115 additions and 31 deletions

79
main.py
View File

@@ -88,10 +88,27 @@ parser.add_argument('--stablediffusion-model', dest='stablediffusion_model', act
# Stable diffusion prompt
parser.add_argument('--stablediffusion-prompt', dest='stablediffusion_prompt', action='store', default=DEFAULT_PROMPT,
help='Stable diffusion prompt')
# Force action
parser.add_argument('--force-action', dest='force_action', action='store', default="",
help='Force an action')
# Debug mode
parser.add_argument('--debug', dest='debug', action='store_true', default=False,
help='Debug mode')
# Parse arguments
args = parser.parse_args()
# Set log level
LOG_LEVEL = "INFO"
def my_filter(record):
return record["level"].no >= logger.level(LOG_LEVEL).no
logger.remove()
logger.add(sys.stderr, filter=my_filter)
if args.debug:
LOG_LEVEL = "DEBUG"
logger.debug("Debug mode on")
FUNCTIONS_MODEL = os.environ.get("FUNCTIONS_MODEL", args.functions_model)
EMBEDDINGS_MODEL = os.environ.get("EMBEDDINGS_MODEL", args.embeddings_model)
@@ -222,8 +239,8 @@ Function call: """
function_parameters = response.choices[0].message["function_call"].arguments
# read the json from the string
res = json.loads(function_parameters)
logger.info(">>> function name: "+function_name)
logger.info(">>> function parameters: "+function_parameters)
logger.debug(">>> function name: "+function_name)
logger.debug(">>> function parameters: "+function_parameters)
return res
return {"action": REPLY_ACTION}
@@ -274,7 +291,7 @@ Function call: """
if response_message.get("function_call"):
function_name = response.choices[0].message["function_call"].name
function_parameters = response.choices[0].message["function_call"].arguments
logger.info("==> function parameters: {function_parameters}",function_parameters=function_parameters)
logger.debug("==> function parameters: {function_parameters}",function_parameters=function_parameters)
function_to_call = agent_actions[function_name]["function"]
function_result = function_to_call(function_parameters, agent_actions=agent_actions)
@@ -300,7 +317,7 @@ def function_completion(messages, action="", agent_actions={}):
function_call = "auto"
if action != "":
function_call={"name": action}
logger.info("==> function name: {function_call}", function_call=function_call)
logger.debug("==> function name: {function_call}", function_call=function_call)
# get the functions from the signatures of the agent actions, if exists
functions = []
for action in agent_actions:
@@ -364,19 +381,32 @@ def converse(responses):
### Fine tune a string before feeding into the LLM
def analyze(responses, prefix="Analyze the following text highlighting the relevant information and identify a list of actions to take if there are any. If there are errors, suggest solutions to fix them"):
def analyze(responses, prefix="Analyze the following text highlighting the relevant information and identify a list of actions to take if there are any. If there are errors, suggest solutions to fix them", suffix=""):
string = process_history(responses)
messages = [
{
"role": "user",
"content": f"""{prefix}:
messages = []
```
{string}
```
""",
}
]
if prefix != "":
messages = [
{
"role": "user",
"content": f"""{prefix}:
```
{string}
```
""",
}
]
else:
messages = [
{
"role": "user",
"content": f"""{string}""",
}
]
if suffix != "":
messages[0]["content"]+=f"""{suffix}"""
response = openai.ChatCompletion.create(
model=LLM_MODEL,
@@ -509,7 +539,7 @@ Function call: """
function_parameters = response.choices[0].message["function_call"].arguments
# read the json from the string
res = json.loads(function_parameters)
logger.info("<<< function name: {function_name} >>>> parameters: {parameters}", function_name=function_name,parameters=function_parameters)
logger.debug("<<< function name: {function_name} >>>> parameters: {parameters}", function_name=function_name,parameters=function_parameters)
return res
return {"action": REPLY_ACTION}
@@ -634,6 +664,10 @@ def evaluate(user_input, conversation_history = [],re_evaluate=False, agent_acti
logger.info("==> LocalAGI wants to call '{action}'", action=action["action"])
#logger.info("==> Observation '{reasoning}'", reasoning=action["observation"])
logger.info("==> Reasoning '{reasoning}'", reasoning=action["reasoning"])
# Force executing a plan instead
if args.force_action:
action["action"] = args.force_action
logger.info("==> Forcing action to '{action}' as requested by the user", action=action["action"])
reasoning = action["reasoning"]
if action["action"] == PLAN_ACTION:
@@ -700,7 +734,7 @@ def evaluate(user_input, conversation_history = [],re_evaluate=False, agent_acti
#responses = converse(responses)
# TODO: this needs to be optimized
responses = analyze(responses, prefix=f"You are an AI assistant. Return an appropriate answer to the user input '{user_input}' given the context below and summarizing the actions taken\n")
responses = analyze(responses[1:], suffix=f"Return an appropriate answer given the context above\n")
# add responses to conversation history by extending the list
conversation_history.append(
@@ -853,11 +887,10 @@ if not args.skip_avatar:
logger.info("Creating avatar, please wait...")
display_avatar()
if not args.prompt:
actions = ""
for action in agent_actions:
actions+=" '"+action+"'"
logger.info("LocalAGI internally can do the following actions:{actions}", actions=actions)
actions = ""
for action in agent_actions:
actions+=" '"+action+"'"
logger.info("LocalAGI internally can do the following actions:{actions}", actions=actions)
if not args.prompt:
logger.info(">>> Interactive mode <<<")