Add sources to results; allow setting the search type and number of results from the config
This commit is contained in:
@@ -37,9 +37,14 @@ FILE_NAME_FORMAT = '%Y_%m_%d_%H_%M_%S'
|
|||||||
EMBEDDINGS_MODEL = config["agent"]["embeddings_model"]
|
EMBEDDINGS_MODEL = config["agent"]["embeddings_model"]
|
||||||
EMBEDDINGS_API_BASE = config["agent"]["embeddings_api_base"]
|
EMBEDDINGS_API_BASE = config["agent"]["embeddings_api_base"]
|
||||||
PERSISTENT_DIR = config["agent"]["persistent_dir"]
|
PERSISTENT_DIR = config["agent"]["persistent_dir"]
|
||||||
MILVUS_HOST = config["agent"]["milvus_host"]
|
MILVUS_HOST = config["agent"]["milvus_host"] if "milvus_host" in config["agent"] else ""
|
||||||
MILVUS_PORT = config["agent"]["milvus_port"]
|
MILVUS_PORT = config["agent"]["milvus_port"] if "milvus_port" in config["agent"] else 0
|
||||||
|
MEMORY_COLLECTION = config["agent"]["memory_collection"]
|
||||||
DB_DIR = config["agent"]["db_dir"]
|
DB_DIR = config["agent"]["db_dir"]
|
||||||
|
MEMORY_CHUNK_SIZE = int(config["agent"]["memory_chunk_size"])
|
||||||
|
MEMORY_CHUNK_OVERLAP = int(config["agent"]["memory_chunk_overlap"])
|
||||||
|
MEMORY_RESULTS = int(config["agent"]["memory_results"])
|
||||||
|
MEMORY_SEARCH_TYPE = config["agent"]["memory_search_type"]
|
||||||
|
|
||||||
if MILVUS_HOST == "":
|
if MILVUS_HOST == "":
|
||||||
if not os.environ.get("PYSQL_HACK", "false") == "false":
|
if not os.environ.get("PYSQL_HACK", "false") == "false":
|
||||||
@@ -61,8 +66,8 @@ def call(thing):
|
|||||||
|
|
||||||
def ingest(a, agent_actions={}, localagi=None):
|
def ingest(a, agent_actions={}, localagi=None):
|
||||||
q = json.loads(a)
|
q = json.loads(a)
|
||||||
chunk_size = 1024
|
chunk_size = MEMORY_CHUNK_SIZE
|
||||||
chunk_overlap = 110
|
chunk_overlap = MEMORY_CHUNK_OVERLAP
|
||||||
logger.info(">>> ingesting: ")
|
logger.info(">>> ingesting: ")
|
||||||
logger.info(q)
|
logger.info(q)
|
||||||
documents = []
|
documents = []
|
||||||
@@ -71,11 +76,11 @@ def ingest(a, agent_actions={}, localagi=None):
|
|||||||
documents.extend(sitemap_loader.load())
|
documents.extend(sitemap_loader.load())
|
||||||
texts = text_splitter.split_documents(documents)
|
texts = text_splitter.split_documents(documents)
|
||||||
if MILVUS_HOST == "":
|
if MILVUS_HOST == "":
|
||||||
db = Chroma.from_documents(texts,embeddings,collection_name="memories", persist_directory=DB_DIR)
|
db = Chroma.from_documents(texts,embeddings,collection_name=MEMORY_COLLECTION, persist_directory=DB_DIR)
|
||||||
db.persist()
|
db.persist()
|
||||||
db = None
|
db = None
|
||||||
else:
|
else:
|
||||||
Milvus.from_documents(texts,embeddings,collection_name="memories", connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
|
Milvus.from_documents(texts,embeddings,collection_name=MEMORY_COLLECTION, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
|
||||||
return f"Documents ingested"
|
return f"Documents ingested"
|
||||||
|
|
||||||
def create_image(a, agent_actions={}, localagi=None):
|
def create_image(a, agent_actions={}, localagi=None):
|
||||||
@@ -111,9 +116,9 @@ def save(memory, agent_actions={}, localagi=None):
|
|||||||
logger.info(">>> saving to memories: ")
|
logger.info(">>> saving to memories: ")
|
||||||
logger.info(q["content"])
|
logger.info(q["content"])
|
||||||
if MILVUS_HOST == "":
|
if MILVUS_HOST == "":
|
||||||
chroma_client = Chroma(collection_name="memories",embedding_function=embeddings, persist_directory=DB_DIR)
|
chroma_client = Chroma(collection_name=MEMORY_COLLECTION,embedding_function=embeddings, persist_directory=DB_DIR)
|
||||||
else:
|
else:
|
||||||
chroma_client = Milvus(collection_name="memories",embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
|
chroma_client = Milvus(collection_name=MEMORY_COLLECTION,embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
|
||||||
chroma_client.add_texts([q["content"]],[{"id": str(uuid.uuid4())}])
|
chroma_client.add_texts([q["content"]],[{"id": str(uuid.uuid4())}])
|
||||||
if MILVUS_HOST == "":
|
if MILVUS_HOST == "":
|
||||||
chroma_client.persist()
|
chroma_client.persist()
|
||||||
@@ -123,15 +128,33 @@ def save(memory, agent_actions={}, localagi=None):
|
|||||||
def search_memory(query, agent_actions={}, localagi=None):
|
def search_memory(query, agent_actions={}, localagi=None):
|
||||||
q = json.loads(query)
|
q = json.loads(query)
|
||||||
if MILVUS_HOST == "":
|
if MILVUS_HOST == "":
|
||||||
chroma_client = Chroma(collection_name="memories",embedding_function=embeddings, persist_directory=DB_DIR)
|
chroma_client = Chroma(collection_name=MEMORY_COLLECTION,embedding_function=embeddings, persist_directory=DB_DIR)
|
||||||
else:
|
else:
|
||||||
chroma_client = Milvus(collection_name="memories",embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
|
chroma_client = Milvus(collection_name=MEMORY_COLLECTION,embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
|
||||||
docs = chroma_client.search(q["keywords"], "mmr")
|
#docs = chroma_client.search(q["keywords"], "mmr")
|
||||||
|
retriever = chroma_client.as_retriever(search_type=MEMORY_SEARCH_TYPE, search_kwargs={"k": MEMORY_RESULTS})
|
||||||
|
|
||||||
|
docs = retriever.get_relevant_documents(q["keywords"])
|
||||||
text_res="Memories found in the database:\n"
|
text_res="Memories found in the database:\n"
|
||||||
|
|
||||||
|
sources = set() # To store unique sources
|
||||||
|
|
||||||
|
# Collect unique sources
|
||||||
|
for document in docs:
|
||||||
|
if "source" in document.metadata:
|
||||||
|
sources.add(document.metadata["source"])
|
||||||
|
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
# drop newlines from page_content
|
# drop newlines from page_content
|
||||||
doc.page_content = " ".join(doc.page_content.replace.split())
|
content = doc.page_content.replace("\n", " ")
|
||||||
text_res+="- "+doc.page_content+"\n"
|
content = " ".join(content.split())
|
||||||
|
text_res+="- "+content+"\n"
|
||||||
|
|
||||||
|
# Print the relevant sources used for the answer
|
||||||
|
for source in sources:
|
||||||
|
if source.startswith("http"):
|
||||||
|
text_res += "" + source + "\n"
|
||||||
|
|
||||||
chroma_client = None
|
chroma_client = None
|
||||||
#if args.postprocess:
|
#if args.postprocess:
|
||||||
# return post_process(text_res)
|
# return post_process(text_res)
|
||||||
|
|||||||
Reference in New Issue
Block a user