Add support for milvus in the discord bot

@@ -10,6 +10,8 @@ from langchain.document_loaders import (
 
 import uuid
 import sys
+from config import config
+
 from queue import Queue
 import asyncio
 import threading
@@ -24,22 +26,31 @@ import discord
 import openai
 import urllib.request
 from datetime import datetime
-# these three lines swap the stdlib sqlite3 lib with the pysqlite3 package for chroma
-__import__('pysqlite3')
-import sys
-sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
-
-from langchain.vectorstores import Chroma
 from chromadb.config import Settings
 import json
 import os
 from io import StringIO
 FILE_NAME_FORMAT = '%Y_%m_%d_%H_%M_%S'
 
-EMBEDDINGS_MODEL = os.environ.get("EMBEDDINGS_MODEL", "all-MiniLM-L6-v2")
-EMBEDDINGS_API_BASE = os.environ.get("EMBEDDINGS_API_BASE", "http://api:8080")
-PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", "/tmp/data/")
-DB_DIR = os.environ.get("DB_DIR", "/tmp/data/db")
+EMBEDDINGS_MODEL = config["agent"]["embeddings_model"]
+EMBEDDINGS_API_BASE = config["agent"]["embeddings_api_base"]
+PERSISTENT_DIR = config["agent"]["persistent_dir"]
+MILVUS_HOST = config["agent"]["milvus_host"]
+MILVUS_PORT = config["agent"]["milvus_port"]
+DB_DIR = config["agent"]["db_dir"]
+
+if MILVUS_HOST == "":
+    if not os.environ.get("PYSQL_HACK", "false") == "false":
+        # these three lines swap the stdlib sqlite3 lib with the pysqlite3 package for chroma
+        __import__('pysqlite3')
+        import sys
+        sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
+
+    from langchain.vectorstores import Chroma
+else:
+    from langchain.vectorstores import Milvus
 
 embeddings = LocalAIEmbeddings(model=EMBEDDINGS_MODEL,openai_api_base=EMBEDDINGS_API_BASE)
 
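A note on the import switch above: whenever config.ini carries a non-empty milvus_host, the agent uses Milvus instead of Chroma as its vector store, so a Milvus server has to be reachable before the bot starts. A minimal pre-flight check, sketched with pymilvus' connections and utility helpers (host and port values are illustrative, not taken from the commit):

    from pymilvus import connections, utility

    # Assumed endpoint; the bot reads these from config["agent"]["milvus_host"] / ["milvus_port"].
    connections.connect(alias="default", host="localhost", port="19530")
    print("collections:", utility.list_collections())  # the agent writes into "memories"
    connections.disconnect("default")
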
@@ -50,8 +61,8 @@ def call(thing):
 
 def ingest(a, agent_actions={}, localagi=None):
     q = json.loads(a)
-    chunk_size = 500
-    chunk_overlap = 50
+    chunk_size = 1024
+    chunk_overlap = 110
     logger.info(">>> ingesting: ")
     logger.info(q)
     documents = []
@@ -59,9 +70,12 @@ def ingest(a, agent_actions={}, localagi=None):
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     documents.extend(sitemap_loader.load())
     texts = text_splitter.split_documents(documents)
-    db = Chroma.from_documents(texts,embeddings,collection_name="memories", persist_directory=DB_DIR)
-    db.persist()
-    db = None
+    if MILVUS_HOST == "":
+        db = Chroma.from_documents(texts,embeddings,collection_name="memories", persist_directory=DB_DIR)
+        db.persist()
+        db = None
+    else:
+        Milvus.from_documents(texts,embeddings,collection_name="memories", connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
     return f"Documents ingested"
 
 def create_image(a, agent_actions={}, localagi=None):
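The ingest change above roughly doubles the chunk size (500 to 1024 characters) and raises the overlap to 110 characters before documents are embedded. A standalone sketch, with made-up filler text, of what those values produce:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    sample = "word " * 600  # ~3000 characters of filler text
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=110)
    chunks = splitter.split_text(sample)
    # a few chunks of up to ~1024 characters, adjacent ones sharing roughly 110 characters
    print(len(chunks), [len(c) for c in chunks])
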
@@ -96,24 +110,33 @@ def save(memory, agent_actions={}, localagi=None):
     q = json.loads(memory)
     logger.info(">>> saving to memories: ")
     logger.info(q["content"])
+    if MILVUS_HOST == "":
         chroma_client = Chroma(collection_name="memories",embedding_function=embeddings, persist_directory=DB_DIR)
+    else:
+        chroma_client = Milvus(collection_name="memories",embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
     chroma_client.add_texts([q["content"]],[{"id": str(uuid.uuid4())}])
+    if MILVUS_HOST == "":
         chroma_client.persist()
     chroma_client = None
     return f"The object was saved permanently to memory."
 
 def search_memory(query, agent_actions={}, localagi=None):
     q = json.loads(query)
+    if MILVUS_HOST == "":
         chroma_client = Chroma(collection_name="memories",embedding_function=embeddings, persist_directory=DB_DIR)
-    docs = chroma_client.similarity_search(q["reasoning"])
+    else:
+        chroma_client = Milvus(collection_name="memories",embedding_function=embeddings, connection_args={"host": MILVUS_HOST, "port": MILVUS_PORT})
+    docs = chroma_client.search(q["keywords"], "mmr")
     text_res="Memories found in the database:\n"
     for doc in docs:
+        # drop newlines from page_content
+        doc.page_content = " ".join(doc.page_content.split())
         text_res+="- "+doc.page_content+"\n"
     chroma_client = None
     #if args.postprocess:
     # return post_process(text_res)
-    #return text_res
-    return localagi.post_process(text_res)
+    return text_res
+    #return localagi.post_process(text_res)
 
 # write file to disk with content
 def save_file(arg, agent_actions={}, localagi=None):
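In search_memory above, the Chroma-specific similarity_search(q["reasoning"]) call is replaced by the store-agnostic VectorStore.search(query, search_type) entry point with "mmr", which works for both Chroma and Milvus and matches the parameter rename to "keywords" in the hunk below. A small sketch of the two call styles against a throwaway Chroma collection (the texts, collection name, and query are invented for illustration):

    from langchain.vectorstores import Chroma
    from langchain.embeddings import LocalAIEmbeddings

    embeddings = LocalAIEmbeddings(model="all-MiniLM-L6-v2", openai_api_base="http://api:8080")
    store = Chroma.from_texts(["LocalAGI stores its memories in a vector database"], embeddings, collection_name="scratch")

    print(store.similarity_search("where are memories kept?"))  # old call, plain similarity search
    print(store.search("where are memories kept?", "mmr"))      # new call, dispatches to MMR retrieval
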
@@ -317,12 +340,12 @@ agent_actions = {
         "parameters": {
             "type": "object",
             "properties": {
-                "reasoning": {
+                "keywords": {
                     "type": "string",
                     "description": "reasoning behind the intent"
                 },
             },
-            "required": ["reasoning"]
+            "required": ["keywords"]
         }
     },
 },

examples/discord/config.py (new file)
@@ -0,0 +1,5 @@
+from configparser import ConfigParser
+
+config_file = "config.ini"
+config = ConfigParser(interpolation=None)
+config.read(config_file)
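config.py above only parses config.ini; the .ini file itself is not part of the commit. For reference, a sketch of the sections and keys the rest of this diff reads from it, parsed the same way (every value here is a placeholder, not taken from the commit):

    from configparser import ConfigParser

    # Keys are the ones read in the agent and bot modules of this diff; values are placeholders.
    SAMPLE_INI = """
    [discord]
    server_id = 000000000000000000
    api_key = REPLACE_ME

    [openai]
    organization =
    api_key =

    [agent]
    api_base = http://api:8080/v1
    embeddings_model = all-MiniLM-L6-v2
    embeddings_api_base = http://api:8080
    persistent_dir = /tmp/data/
    db_dir = /tmp/data/db
    milvus_host =
    milvus_port = 19530

    [settings]
    file_path = /tmp/data/
    file_name_format = %Y_%m_%d_%H_%M_%S
    """

    config = ConfigParser(interpolation=None)
    config.read_string(SAMPLE_INI)
    print(config["agent"]["milvus_host"])  # empty string, so the agent keeps using Chroma
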

@@ -7,12 +7,21 @@ GitHub: https://https://github.com/StefanRial/ClaudeBot
 E-Mail: mail.stefanrial@gmail.com
 """
-import discord
-import openai
-import urllib.request
+from config import config
 import os
 
+OPENAI_API_KEY = config["openai"][str("api_key")]
+
+if OPENAI_API_KEY == "":
+    OPENAI_API_KEY = "foo"
+os.environ["OPENAI_API_BASE"] = config["agent"]["api_base"]
+os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
+
+import openai
+
+import discord
+
+import urllib.request
 from datetime import datetime
-from configparser import ConfigParser
 from queue import Queue
 import agent
 from agent import agent_actions
@@ -23,14 +32,11 @@ from discord import app_commands
 import functools
 import typing
 
-config_file = "config.ini"
-config = ConfigParser(interpolation=None)
-config.read(config_file)
 
 SERVER_ID = config["discord"]["server_id"]
 DISCORD_API_KEY = config["discord"][str("api_key")]
 OPENAI_ORG = config["openai"][str("organization")]
-OPENAI_API_KEY = config["openai"][str("api_key")]
+
 
 FILE_PATH = config["settings"][str("file_path")]
 FILE_NAME_FORMAT = config["settings"][str("file_name_format")]
@@ -126,6 +132,7 @@ def run_localagi_thread_history(history, message, thread, loop):
         message.content,
         history,
         subtaskContext=True,
+        critic=True,
     )
 
     analyze_history(history, conversation_history, call, thread)
@@ -161,6 +168,7 @@ def run_localagi_message(message, loop):
     conversation_history = localagi.evaluate(
         message.content,
         [],
+        critic=True,
         subtaskContext=True,
     )
     analyze_history([], conversation_history, call, message.channel)
@@ -217,6 +225,7 @@ def run_localagi(interaction, prompt, loop):
         prompt,
         messages,
         subtaskContext=True,
+        critic=True,
     )
     analyze_history(messages, conversation_history, call, interaction.channel)
     call(sent_message.edit(content=f"<@{user.id}> {conversation_history[-1]['content']}"))

@@ -8,3 +8,4 @@ chromadb
 pysqlite3-binary
 langchain
 beautifulsoup4
+pymilvus