From 414f9ca765d84fc5c1692e2773cc3cbe2ee0127d Mon Sep 17 00:00:00 2001 From: mudler Date: Sat, 26 Aug 2023 00:42:49 +0200 Subject: [PATCH] Add sources in results, allow to set search type and results from configs --- examples/discord/agent.py | 49 ++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/examples/discord/agent.py b/examples/discord/agent.py index 9a086d7..8295e08 100644 --- a/examples/discord/agent.py +++ b/examples/discord/agent.py @@ -37,9 +37,14 @@ FILE_NAME_FORMAT = '%Y_%m_%d_%H_%M_%S' EMBEDDINGS_MODEL = config["agent"]["embeddings_model"] EMBEDDINGS_API_BASE = config["agent"]["embeddings_api_base"] PERSISTENT_DIR = config["agent"]["persistent_dir"] -MILVUS_HOST = config["agent"]["milvus_host"] -MILVUS_PORT = config["agent"]["milvus_port"] +MILVUS_HOST = config["agent"]["milvus_host"] if "milvus_host" in config["agent"] else "" +MILVUS_PORT = config["agent"]["milvus_port"] if "milvus_port" in config["agent"] else 0 +MEMORY_COLLECTION = config["agent"]["memory_collection"] DB_DIR = config["agent"]["db_dir"] +MEMORY_CHUNK_SIZE = int(config["agent"]["memory_chunk_size"]) +MEMORY_CHUNK_OVERLAP = int(config["agent"]["memory_chunk_overlap"]) +MEMORY_RESULTS = int(config["agent"]["memory_results"]) +MEMORY_SEARCH_TYPE = config["agent"]["memory_search_type"] if MILVUS_HOST == "": if not os.environ.get("PYSQL_HACK", "false") == "false": @@ -61,8 +66,8 @@ def call(thing): def ingest(a, agent_actions={}, localagi=None): q = json.loads(a) - chunk_size = 1024 - chunk_overlap = 110 + chunk_size = MEMORY_CHUNK_SIZE + chunk_overlap = MEMORY_CHUNK_OVERLAP logger.info(">>> ingesting: ") logger.info(q) documents = [] @@ -71,11 +76,11 @@ def ingest(a, agent_actions={}, localagi=None): documents.extend(sitemap_loader.load()) texts = text_splitter.split_documents(documents) if MILVUS_HOST == "": - db = Chroma.from_documents(texts,embeddings,collection_name="memories", persist_directory=DB_DIR) + db = Chroma.from_documents(texts,embeddings,collection_name=MEMORY_COLLECTION, persist_directory=DB_DIR) db.persist() db = None else: - Milvus.from_documents(texts,embeddings,collection_name="memories", connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT}) + Milvus.from_documents(texts,embeddings,collection_name=MEMORY_COLLECTION, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT}) return f"Documents ingested" def create_image(a, agent_actions={}, localagi=None): @@ -111,9 +116,9 @@ def save(memory, agent_actions={}, localagi=None): logger.info(">>> saving to memories: ") logger.info(q["content"]) if MILVUS_HOST == "": - chroma_client = Chroma(collection_name="memories",embedding_function=embeddings, persist_directory=DB_DIR) + chroma_client = Chroma(collection_name=MEMORY_COLLECTION,embedding_function=embeddings, persist_directory=DB_DIR) else: - chroma_client = Milvus(collection_name="memories",embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT}) + chroma_client = Milvus(collection_name=MEMORY_COLLECTION,embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT}) chroma_client.add_texts([q["content"]],[{"id": str(uuid.uuid4())}]) if MILVUS_HOST == "": chroma_client.persist() @@ -123,15 +128,33 @@ def save(memory, agent_actions={}, localagi=None): def search_memory(query, agent_actions={}, localagi=None): q = json.loads(query) if MILVUS_HOST == "": - chroma_client = Chroma(collection_name="memories",embedding_function=embeddings, persist_directory=DB_DIR) + chroma_client = Chroma(collection_name=MEMORY_COLLECTION,embedding_function=embeddings, persist_directory=DB_DIR) else: - chroma_client = Milvus(collection_name="memories",embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT}) - docs = chroma_client.search(q["keywords"], "mmr") + chroma_client = Milvus(collection_name=MEMORY_COLLECTION,embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT}) + #docs = chroma_client.search(q["keywords"], "mmr") + retriever = chroma_client.as_retriever(search_type=MEMORY_SEARCH_TYPE, search_kwargs={"k": MEMORY_RESULTS}) + + docs = retriever.get_relevant_documents(q["keywords"]) text_res="Memories found in the database:\n" + + sources = set() # To store unique sources + + # Collect unique sources + for document in docs: + if "source" in document.metadata: + sources.add(document.metadata["source"]) + for doc in docs: # drop newlines from page_content - doc.page_content = " ".join(doc.page_content.replace.split()) - text_res+="- "+doc.page_content+"\n" + content = doc.page_content.replace("\n", " ") + content = " ".join(content.split()) + text_res+="- "+content+"\n" + + # Print the relevant sources used for the answer + for source in sources: + if source.startswith("http"): + text_res += "" + source + "\n" + chroma_client = None #if args.postprocess: # return post_process(text_res)