Add support for milvus in the discord bot

This commit is contained in:
mudler
2023-08-26 00:17:19 +02:00
parent 317b91a7e9
commit a7462249a7
4 changed files with 70 additions and 32 deletions

View File

@@ -10,6 +10,8 @@ from langchain.document_loaders import (
 import uuid
 import sys
+from config import config
 from queue import Queue
 import asyncio
 import threading
@@ -24,22 +26,31 @@ import discord
 import openai
 import urllib.request
 from datetime import datetime
-# these three lines swap the stdlib sqlite3 lib with the pysqlite3 package for chroma
-__import__('pysqlite3')
-import sys
-sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
-from langchain.vectorstores import Chroma
 from chromadb.config import Settings
 import json
 import os
 from io import StringIO
 FILE_NAME_FORMAT = '%Y_%m_%d_%H_%M_%S'
-EMBEDDINGS_MODEL = os.environ.get("EMBEDDINGS_MODEL", "all-MiniLM-L6-v2")
-EMBEDDINGS_API_BASE = os.environ.get("EMBEDDINGS_API_BASE", "http://api:8080")
-PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", "/tmp/data/")
-DB_DIR = os.environ.get("DB_DIR", "/tmp/data/db")
+EMBEDDINGS_MODEL = config["agent"]["embeddings_model"]
+EMBEDDINGS_API_BASE = config["agent"]["embeddings_api_base"]
+PERSISTENT_DIR = config["agent"]["persistent_dir"]
+MILVUS_HOST = config["agent"]["milvus_host"]
+MILVUS_PORT = config["agent"]["milvus_port"]
+DB_DIR = config["agent"]["db_dir"]
+
+if MILVUS_HOST == "":
+    if not os.environ.get("PYSQL_HACK", "false") == "false":
+        # these three lines swap the stdlib sqlite3 lib with the pysqlite3 package for chroma
+        __import__('pysqlite3')
+        import sys
+        sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
+    from langchain.vectorstores import Chroma
+else:
+    from langchain.vectorstores import Milvus
+
 embeddings = LocalAIEmbeddings(model=EMBEDDINGS_MODEL,openai_api_base=EMBEDDINGS_API_BASE)
@@ -50,8 +61,8 @@ def call(thing):
 def ingest(a, agent_actions={}, localagi=None):
     q = json.loads(a)
-    chunk_size = 500
-    chunk_overlap = 50
+    chunk_size = 1024
+    chunk_overlap = 110
     logger.info(">>> ingesting: ")
     logger.info(q)
     documents = []
@@ -59,9 +70,12 @@ def ingest(a, agent_actions={}, localagi=None):
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     documents.extend(sitemap_loader.load())
     texts = text_splitter.split_documents(documents)
-    db = Chroma.from_documents(texts,embeddings,collection_name="memories", persist_directory=DB_DIR)
-    db.persist()
-    db = None
+    if MILVUS_HOST == "":
+        db = Chroma.from_documents(texts,embeddings,collection_name="memories", persist_directory=DB_DIR)
+        db.persist()
+        db = None
+    else:
+        Milvus.from_documents(texts,embeddings,collection_name="memories", connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
     return f"Documents ingested"

 def create_image(a, agent_actions={}, localagi=None):
@@ -96,24 +110,33 @@ def save(memory, agent_actions={}, localagi=None):
     q = json.loads(memory)
     logger.info(">>> saving to memories: ")
     logger.info(q["content"])
-    chroma_client = Chroma(collection_name="memories",embedding_function=embeddings, persist_directory=DB_DIR)
+    if MILVUS_HOST == "":
+        chroma_client = Chroma(collection_name="memories",embedding_function=embeddings, persist_directory=DB_DIR)
+    else:
+        chroma_client = Milvus(collection_name="memories",embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
     chroma_client.add_texts([q["content"]],[{"id": str(uuid.uuid4())}])
-    chroma_client.persist()
-    chroma_client = None
+    if MILVUS_HOST == "":
+        chroma_client.persist()
+        chroma_client = None
     return f"The object was saved permanently to memory."

 def search_memory(query, agent_actions={}, localagi=None):
     q = json.loads(query)
-    chroma_client = Chroma(collection_name="memories",embedding_function=embeddings, persist_directory=DB_DIR)
-    docs = chroma_client.similarity_search(q["reasoning"])
+    if MILVUS_HOST == "":
+        chroma_client = Chroma(collection_name="memories",embedding_function=embeddings, persist_directory=DB_DIR)
+    else:
+        chroma_client = Milvus(collection_name="memories",embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
+    docs = chroma_client.search(q["keywords"], "mmr")
     text_res="Memories found in the database:\n"
     for doc in docs:
+        # drop newlines from page_content
+        doc.page_content = " ".join(doc.page_content.split())
         text_res+="- "+doc.page_content+"\n"
     chroma_client = None
     #if args.postprocess:
     #    return post_process(text_res)
-    #return text_res
-    return localagi.post_process(text_res)
+    return text_res
+    #return localagi.post_process(text_res)

 # write file to disk with content
 def save_file(arg, agent_actions={}, localagi=None):
@@ -317,12 +340,12 @@ agent_actions = {
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"reasoning": { "keywords": {
"type": "string", "type": "string",
"description": "reasoning behind the intent" "description": "reasoning behind the intent"
}, },
}, },
"required": ["reasoning"] "required": ["keywords"]
} }
}, },
}, },
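Both backends sit behind LangChain's common VectorStore interface, which is why the same from_documents, add_texts and search calls work on either side of the MILVUS_HOST check above. As an illustration only (not part of this commit), the selection logic could be summed up in a small helper:

# Sketch, not part of the commit: mirrors the MILVUS_HOST switch used in
# ingest/save/search_memory above, with the same constructor arguments.
from langchain.vectorstores import Chroma, Milvus

def memories_store(embeddings, milvus_host, milvus_port, db_dir):
    """Return the 'memories' vector store for the configured backend."""
    if milvus_host == "":
        # Local, file-backed Chroma collection persisted under db_dir
        return Chroma(collection_name="memories",
                      embedding_function=embeddings,
                      persist_directory=db_dir)
    # Remote Milvus collection reached over host/port
    return Milvus(collection_name="memories",
                  embedding_function=embeddings,
                  connection_args={"host": milvus_host, "port": milvus_port})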

View File

@@ -0,0 +1,5 @@
+from configparser import ConfigParser
+
+config_file = "config.ini"
+config = ConfigParser(interpolation=None)
+config.read(config_file)
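The new config module only loads config.ini; the sections and options themselves come from the lookups scattered through the diff. A hypothetical config.ini covering those keys could look like the following (all values are placeholders; the defaults mirror the old environment-variable fallbacks where the diff shows them, and 19530 is Milvus' standard port):

[discord]
server_id = 000000000000000000
api_key = YOUR_DISCORD_BOT_TOKEN

[openai]
organization =
api_key =

[agent]
api_base = http://api:8080/v1
embeddings_model = all-MiniLM-L6-v2
embeddings_api_base = http://api:8080
persistent_dir = /tmp/data/
db_dir = /tmp/data/db
; leave milvus_host empty to keep using the local Chroma store
milvus_host =
milvus_port = 19530

[settings]
file_path = /tmp/data/
file_name_format = %Y_%m_%d_%H_%M_%S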

View File

@@ -7,12 +7,21 @@ GitHub: https://https://github.com/StefanRial/ClaudeBot
 E-Mail: mail.stefanrial@gmail.com
 """
-import discord
-import openai
-import urllib.request
+from config import config
 import os
+
+OPENAI_API_KEY = config["openai"][str("api_key")]
+if OPENAI_API_KEY == "":
+    OPENAI_API_KEY = "foo"
+os.environ["OPENAI_API_BASE"] = config["agent"]["api_base"]
+os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
+
+import openai
+import discord
+import urllib.request
 from datetime import datetime
-from configparser import ConfigParser
 from queue import Queue
 import agent
 from agent import agent_actions
@@ -23,14 +32,11 @@ from discord import app_commands
 import functools
 import typing

-config_file = "config.ini"
-config = ConfigParser(interpolation=None)
-config.read(config_file)
 SERVER_ID = config["discord"]["server_id"]
 DISCORD_API_KEY = config["discord"][str("api_key")]
 OPENAI_ORG = config["openai"][str("organization")]
-OPENAI_API_KEY = config["openai"][str("api_key")]
 FILE_PATH = config["settings"][str("file_path")]
 FILE_NAME_FORMAT = config["settings"][str("file_name_format")]
@@ -126,6 +132,7 @@ def run_localagi_thread_history(history, message, thread, loop):
         message.content,
         history,
         subtaskContext=True,
+        critic=True,
     )
     analyze_history(history, conversation_history, call, thread)
@@ -161,6 +168,7 @@ def run_localagi_message(message, loop):
     conversation_history = localagi.evaluate(
         message.content,
         [],
+        critic=True,
         subtaskContext=True,
     )
     analyze_history([], conversation_history, call, message.channel)
@@ -217,6 +225,7 @@ def run_localagi(interaction, prompt, loop):
         prompt,
         messages,
         subtaskContext=True,
+        critic=True,
     )
     analyze_history(messages, conversation_history, call, interaction.channel)
     call(sent_message.edit(content=f"<@{user.id}> {conversation_history[-1]['content']}"))
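The import reshuffle at the top of this file is deliberate: OPENAI_API_BASE and OPENAI_API_KEY are placed in the environment before import openai, presumably because the pre-1.0 openai client reads those variables when the module is first loaded. A minimal standalone sketch of the same pattern (the endpoint value is a placeholder, not from the commit):

import os

# Set the environment before the openai module is loaded; openai 0.x
# snapshots OPENAI_API_BASE / OPENAI_API_KEY at import time.
os.environ["OPENAI_API_BASE"] = "http://api:8080/v1"  # placeholder LocalAI endpoint
os.environ["OPENAI_API_KEY"] = "foo"  # LocalAI typically ignores the key, but the client requires a non-empty one

import openai  # noqa: E402 - intentionally imported after the environment is prepared

print(openai.api_base)  # reflects the value set above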

View File

@@ -8,3 +8,4 @@ chromadb
 pysqlite3-binary
 langchain
 beautifulsoup4
+pymilvus
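pymilvus is the new runtime dependency behind the Milvus code paths above. Not part of the commit, but a quick way to verify the bot can reach the configured instance (host and port here are placeholders; Milvus listens on 19530 by default):

from pymilvus import connections, utility

# Connect using the same host/port the bot reads from config.ini
connections.connect(alias="default", host="localhost", port="19530")

# After an ingest run, the "memories" collection should show up here
print(utility.list_collections())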