Upload 41 files
- .gitattributes +2 -0
- Chroma_db/readme.txt +0 -0
- Config/__pycache__/config.cpython-310.pyc +0 -0
- Config/config.py +14 -0
- Faiss_db/readme.txt +0 -0
- Faiss_db/sss1/index.faiss +0 -0
- Faiss_db/sss1/index.pkl +3 -0
- Neo4j/__pycache__/graph_extract.cpython-310.pyc +0 -0
- Neo4j/__pycache__/neo4j_op.cpython-310.pyc +0 -0
- Neo4j/graph_extract.py +69 -0
- Neo4j/neo4j_op.py +105 -0
- Ollama_api/__pycache__/ollama_api.cpython-310.pyc +0 -0
- Ollama_api/ollama_api.py +21 -0
- app.py +354 -0
- embeding/__pycache__/asr_utils.cpython-310.pyc +0 -0
- embeding/__pycache__/chromadb.cpython-310.pyc +0 -0
- embeding/__pycache__/elasticsearchStore.cpython-310.pyc +0 -0
- embeding/__pycache__/faissdb.cpython-310.pyc +0 -0
- embeding/asr_utils.py +17 -0
- embeding/chromadb.py +134 -0
- embeding/elasticsearchStore.py +147 -0
- embeding/faissdb.py +138 -0
- embeding/tmp.txt +2 -0
- graph_demo_ui.py +87 -0
- img/graph-tool.png +3 -0
- img/readme.txt +1 -0
- img/zhu.png +3 -0
- img/zhuye.png +0 -0
- img/复杂方式.png +0 -0
- img/微信图片_20240524180648.jpg +0 -0
- rag/__init__.py +0 -0
- rag/__pycache__/__init__.cpython-310.pyc +0 -0
- rag/__pycache__/config.cpython-310.pyc +0 -0
- rag/__pycache__/rag_class.cpython-310.pyc +0 -0
- rag/__pycache__/rerank.cpython-310.pyc +0 -0
- rag/__pycache__/rerank.cpython-39.pyc +0 -0
- rag/__pycache__/rerank_code.cpython-310.pyc +0 -0
- rag/rag_class.py +169 -0
- rag/rerank_code.py +21 -0
- requirements.txt +10 -0
- test/__init__.py +0 -0
- test/graph2neo4j.py +25 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+img/graph-tool.png filter=lfs diff=lfs merge=lfs -text
+img/zhu.png filter=lfs diff=lfs merge=lfs -text
Chroma_db/readme.txt
ADDED
File without changes

Config/__pycache__/config.cpython-310.pyc
ADDED
Binary file (362 Bytes)
Config/config.py
ADDED
@@ -0,0 +1,14 @@

# Vector database selection: chroma = 1, faiss = 2, ElasticsearchStore = 3
VECTOR_DB = 2
DB_directory = "./Chroma_db/"
if VECTOR_DB == 2:
    DB_directory = "./Faiss_db/"
elif VECTOR_DB == 3:
    DB_directory = "es"

# Neo4j configuration
neo4j_host = "bolt://localhost:7687"
neo4j_name = "neo4j"
neo4j_pwd = "12345678"
# Tested llama3:8b, gemma2:9b, qwen2:7b, glm4:9b and arcee-ai/arcee-agent:latest; qwen2:7b works best so far
neo4j_model = "qwen2:7b"
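The Neo4j settings above are consumed by Neo4j/neo4j_op.py and Neo4j/graph_extract.py later in this commit. As a minimal connectivity check, a sketch (assuming a local Neo4j instance reachable with these credentials) using the same py2neo call that KnowledgeGraph makes:

from py2neo import Graph

from Config.config import neo4j_host, neo4j_name, neo4j_pwd

# Open a Bolt connection with the configured credentials and run a trivial query.
graph = Graph(neo4j_host, auth=(neo4j_name, neo4j_pwd))
print(graph.run("RETURN 1 AS ok").data())  # expect [{'ok': 1}] if the server is reachable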
Faiss_db/readme.txt
ADDED
File without changes

Faiss_db/sss1/index.faiss
ADDED
Binary file (82 kB)

Faiss_db/sss1/index.pkl
ADDED
@@ -0,0 +1,3 @@

version https://git-lfs.github.com/spec/v1
oid sha256:2bb588f4bd46218f42b045c42163bdcf3cc76a19e37458823ceaeaf8a1454e3b
size 9362
Neo4j/__pycache__/graph_extract.cpython-310.pyc
ADDED
Binary file (2.51 kB)

Neo4j/__pycache__/neo4j_op.cpython-310.pyc
ADDED
Binary file (3.89 kB)
Neo4j/graph_extract.py
ADDED
@@ -0,0 +1,69 @@

from langchain_community.llms import Ollama
from Config.config import neo4j_model

# Tested llama3:8b, gemma2:9b, qwen2:7b, glm4:9b and arcee-ai/arcee-agent:latest; qwen2:7b works best so far
llm = Ollama(model=neo4j_model)

json_example = {'edges': [
    {
        'label': 'label 1',
        'source': 'source 1',
        'target': 'target 1'},
    {
        'label': 'label 1',
        'source': 'source 1',
        'target': 'target 1'}
],
    'nodes': [{'name': 'label 1'},
              {'name': 'label 2'},
              {'name': 'label 3'}]
}

__retriever_prompt = f"""
You are an AI expert specializing in knowledge-graph creation, and your goal is to capture relationships from the given input or request.
The user input may take various forms, such as paragraphs, emails, or text files.
Your task is to build a knowledge graph from that input.
Each element of nodes has only a name parameter, whose value is an entity taken from a word or phrase in the input.
Each edge must also carry a label parameter, where the label is a direct word or phrase from the input, and the source and target of each edge are taken from the name values in nodes.

Respond with JSON only, in a form that can be jsonify-ed in Python and passed directly to cy.add(data).
You may refer to the given example: {json_example}. Do not leave a comma after the last element of the node and edge arrays,
and make sure each edge's source and target match existing nodes.
Do not wrap the JSON in markdown triple backticks above or below; enclose it in braces only.
"""


def generate_graph_info(raw_text: str) -> str | None:
    """
    generate graph info from raw text
    :param raw_text:
    :return:
    """
    messages = [
        {"role": "system", "content": "You now play the role of an information extractor: based on the user input and the AI's answers, extract the information correctly, and remember not to translate the entities."},
        {"role": "user", "content": raw_text},
        {"role": "user", "content": __retriever_prompt}
    ]
    print("Parsing....")
    for i in range(3):
        graph_info_result = llm.invoke(messages)
        if len(graph_info_result) < 10:
            print("-------", i, "-------------------")
            continue
        else:
            break
    print(graph_info_result)
    return graph_info_result


def update_graph(raw_text):
    # raw_text = request.json.get('text', '')
    try:
        result = generate_graph_info(raw_text)
        if '```' in result:
            graph_data = eval(result.split('```', 2)[1].replace("json", ''))
        else:
            graph_data = eval(str(result))
        return graph_data
    except Exception as e:
        return {'error': f"Error parsing graph data: {str(e)}"}
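A minimal usage sketch, not part of the commit (it assumes Ollama is running locally with the configured model pulled):

from Neo4j.graph_extract import update_graph

# Extract a small knowledge graph from free text; returns a dict with 'nodes'
# and 'edges' on success, or an {'error': ...} dict if parsing fails.
data = update_graph("Alice works at Acme and reports to Bob.")
print(data.get("nodes"), data.get("edges"))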
Neo4j/neo4j_op.py
ADDED
@@ -0,0 +1,105 @@

from py2neo import Graph, Node, Relationship
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter


class KnowledgeGraph:
    def __init__(self, uri, user, password):
        self.graph = Graph(uri, auth=(user, password))

    def parse_data(self, file):
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except:
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        return data

    # Split the data into chunks
    def split_files(self, files, chunk_size=500, chunk_overlap=100):
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start building the database ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = text_splitter.split_documents(tmps)

        return splits

    def create_node(self, label, properties):
        matcher = self.graph.nodes.match(label, **properties)
        if matcher.first():
            return matcher.first()
        else:
            node = Node(label, **properties)
            self.graph.create(node)
            return node

    def create_relationship(self, label1, properties1, label2, properties2, relationship_type,
                            relationship_properties={}):
        node1 = self.create_node(label1, properties1)
        node2 = self.create_node(label2, properties2)

        matcher = self.graph.match((node1, node2), r_type=relationship_type)
        for rel in matcher:
            if all(rel[key] == value for key, value in relationship_properties.items()):
                return rel

        relationship = Relationship(node1, relationship_type, node2, **relationship_properties)
        self.graph.create(relationship)
        return relationship

    def delete_node(self, label, properties):
        matcher = self.graph.nodes.match(label, **properties)
        node = matcher.first()
        if node:
            self.graph.delete(node)
            return True
        return False

    def update_node(self, label, identifier, updates):
        matcher = self.graph.nodes.match(label, **identifier)
        node = matcher.first()
        if node:
            for key, value in updates.items():
                node[key] = value
            self.graph.push(node)
            return node
        return None

    def find_node(self, label, properties):
        matcher = self.graph.nodes.match(label, **properties)
        return list(matcher)

    def create_nodes(self, label, properties_list):
        nodes = []
        for properties in properties_list:
            node = self.create_node(label, properties)
            nodes.append(node)
        return nodes

    def create_relationships(self, relationships):
        created_relationships = []
        for rel in relationships:
            label1, properties1, label2, properties2, relationship_type = rel
            relationship = self.create_relationship(label1, properties1, label2, properties2, relationship_type)
            created_relationships.append(relationship)
        return created_relationships
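Nothing in the commit shows the two Neo4j modules wired together, so the following is only a sketch (it assumes the graph JSON from update_graph parses cleanly, and uses a hypothetical "Entity" label for all extracted nodes):

from Config.config import neo4j_host, neo4j_name, neo4j_pwd
from Neo4j.graph_extract import update_graph
from Neo4j.neo4j_op import KnowledgeGraph

kg = KnowledgeGraph(neo4j_host, neo4j_name, neo4j_pwd)
data = update_graph("Alice works at Acme.")

# One node per extracted entity, then one relationship per edge,
# reusing the LLM's edge label as the Neo4j relationship type.
for node in data.get("nodes", []):
    kg.create_node("Entity", {"name": node["name"]})
for edge in data.get("edges", []):
    kg.create_relationship("Entity", {"name": edge["source"]},
                           "Entity", {"name": edge["target"]},
                           edge["label"])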
Ollama_api/__pycache__/ollama_api.cpython-310.pyc
ADDED
Binary file (721 Bytes)
Ollama_api/ollama_api.py
ADDED
@@ -0,0 +1,21 @@

import requests
import json

# Helpers that query the Ollama API for the locally available model list
def get_llm():
    response = requests.get(url="http://localhost:11434/api/tags")
    result = json.loads(response.content)
    llms = []
    for llm in result["models"]:
        if "code" not in llm["name"] and "embed" not in llm["name"]:
            llms.append(llm["name"])
    return llms

def get_embeding_model():
    response = requests.get(url="http://localhost:11434/api/tags")
    result = json.loads(response.content)
    llms = []
    for llm in result["models"]:
        if "embed" in llm["name"]:
            llms.append(llm["name"])
    return llms
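A quick check of what these helpers return, as a sketch (assuming an Ollama server on the default port; the listed names depend on what has been pulled locally):

from Ollama_api.ollama_api import get_llm, get_embeding_model

print("chat models:", get_llm())                  # e.g. ['qwen2:7b', 'llama3:8b']
print("embedding models:", get_embeding_model())  # e.g. ['mofanke/acge_text_embedding:latest']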
app.py
ADDED
@@ -0,0 +1,354 @@

import gradio as gr
import threading
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
import requests
import json

# These are the project's own modules; adjust them to fit your setup
from Config.config import VECTOR_DB, DB_directory
from Ollama_api.ollama_api import *
from rag.rag_class import *

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pick the vector database according to VECTOR_DB
if VECTOR_DB == 1:
    from embeding.chromadb import ChromaDB as vectorDB
    vectordb = vectorDB(persist_directory=DB_directory)
elif VECTOR_DB == 2:
    from embeding.faissdb import FaissDB as vectorDB
    vectordb = vectorDB(persist_directory=DB_directory)
elif VECTOR_DB == 3:
    from embeding.elasticsearchStore import ElsStore as vectorDB
    vectordb = vectorDB()

# Uploaded file paths
uploaded_files = []

@lru_cache(maxsize=100)
def get_knowledge_base_files():
    cl_dict = {}
    cols = vectordb.get_all_collections_name()
    for c_name in cols:
        cl_dict[c_name] = vectordb.get_collcetion_content_files(c_name)
    return cl_dict

knowledge_base_files = get_knowledge_base_files()

def upload_files(files):
    if files:
        new_files = [file.name for file in files]
        uploaded_files.extend(new_files)
        update_knowledge_base_files()
        logger.info(f"Uploaded files: {new_files}")
        return update_file_list(), new_files, "<div style='color: green; padding: 10px; border: 2px solid green; border-radius: 5px;'>Upload successful!</div>"
    update_knowledge_base_files()
    return update_file_list(), [], "<div style='color: red; padding: 10px; border: 2px solid red; border-radius: 5px;'>Upload failed!</div>"

def delete_files(selected_files):
    global uploaded_files
    uploaded_files = [f for f in uploaded_files if f not in selected_files]
    if selected_files:
        update_knowledge_base_files()
        logger.info(f"Deleted files: {selected_files}")
        return update_file_list(), "<div style='color: green; padding: 10px; border: 2px solid green; border-radius: 5px;'>Delete successful!</div>"
    update_knowledge_base_files()
    return update_file_list(), "<div style='color: red; padding: 10px; border: 2px solid red; border-radius: 5px;'>Delete failed!</div>"

def delete_collection(selected_knowledge_base):
    if selected_knowledge_base and selected_knowledge_base != "Create knowledge base":
        vectordb.delete_collection(selected_knowledge_base)
        update_knowledge_base_files()
        logger.info(f"Deleted collection: {selected_knowledge_base}")
        return update_knowledge_base_dropdown(), "<div style='color: green; padding: 10px; border: 2px solid green; border-radius: 5px;'>Collection deleted successfully!</div>"
    return update_knowledge_base_dropdown(), "<div style='color: red; padding: 10px; border: 2px solid red; border-radius: 5px;'>Delete collection failed!</div>"

async def async_vectorize_files(selected_files, selected_knowledge_base, new_kb_name, chunk_size, chunk_overlap):
    if selected_files:
        if selected_knowledge_base == "Create knowledge base":
            knowledge_base = new_kb_name
            vectordb.create_collection(selected_files, knowledge_base, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        else:
            knowledge_base = selected_knowledge_base
            vectordb.add_chroma(selected_files, knowledge_base, chunk_size=chunk_size, chunk_overlap=chunk_overlap)

        if knowledge_base not in knowledge_base_files:
            knowledge_base_files[knowledge_base] = []
        knowledge_base_files[knowledge_base].extend(selected_files)

        logger.info(f"Vectorized files: {selected_files} for knowledge base: {knowledge_base}")
        await asyncio.sleep(0)  # let other tasks run
        return f"Vectorized files: {', '.join(selected_files)}\nKnowledge Base: {knowledge_base}\nUploaded Files: {', '.join(uploaded_files)}", "<div style='color: green; padding: 10px; border: 2px solid green; border-radius: 5px;'>Vectorization successful!</div>"
    return "", "<div style='color: red; padding: 10px; border: 2px solid red; border-radius: 5px;'>Vectorization failed!</div>"

def update_file_list():
    return gr.update(choices=uploaded_files, value=[])

def search_knowledge_base(selected_knowledge_base):
    if selected_knowledge_base in knowledge_base_files:
        kb_files = knowledge_base_files[selected_knowledge_base]
        return gr.update(choices=kb_files, value=[])
    return gr.update(choices=[], value=[])

def update_knowledge_base_files():
    global knowledge_base_files
    knowledge_base_files = get_knowledge_base_files()

# Chat-message handling
chat_history = []

def safe_chat_response(model_dropdown, vector_dropdown, chat_knowledge_base_dropdown, chain_dropdown, message):
    try:
        return chat_response(model_dropdown, vector_dropdown, chat_knowledge_base_dropdown, chain_dropdown, message)
    except Exception as e:
        logger.error(f"Error in chat response: {str(e)}")
        return f"<div style='color: red;'>Error: {str(e)}</div>", ""

def chat_response(model_dropdown, vector_dropdown, chat_knowledge_base_dropdown, chain_dropdown, message):
    global chat_history
    if message:
        chat_history.append(("User", message))
        if chat_knowledge_base_dropdown == "Model only":
            rag = RAG_class(model=model_dropdown, persist_directory=DB_directory)
            answer = rag.mult_chat(chat_history)
        if chat_knowledge_base_dropdown and chat_knowledge_base_dropdown != "Model only":
            rag = RAG_class(model=model_dropdown, embed=vector_dropdown, c_name=chat_knowledge_base_dropdown, persist_directory=DB_directory)
            if chain_dropdown == "Complex retrieval":
                questions = rag.decomposition_chain(message)
                answer = rag.rag_chain(questions)
            elif chain_dropdown == "Simple retrieval":
                answer = rag.simple_chain(message)
            else:
                answer = rag.rerank_chain(message)

        response = f" {answer}"
        chat_history.append(("Bot", response))
    return format_chat_history(chat_history), ""

def clear_chat():
    global chat_history
    chat_history = []
    return format_chat_history(chat_history)

def format_chat_history(history):
    formatted_history = ""
    for user, msg in history:
        if user == "User":
            formatted_history += f'''
            <div style="text-align: right; margin: 10px;">
                <div style="display: inline-block; background-color: #DCF8C6; padding: 10px; border-radius: 10px; max-width: 60%;">
                    {msg}
                </div>
                <b>:User</b>
            </div>
            '''
        else:
            if "```" in msg:  # does the reply contain a code snippet?
                code_content = msg.split("```")[1]
                formatted_history += f'''
                <div style="text-align: left; margin: 10px;">
                    <b>Bot:</b>
                    <div style="display: inline-block; background-color: #F1F0F0; padding: 10px; border-radius: 10px; max-width: 60%;">
                        <pre><code>{code_content}</code></pre>
                    </div>
                </div>
                '''
            else:
                formatted_history += f'''
                <div style="text-align: left; margin: 10px;">
                    <b>Bot:</b>
                    <div style="display: inline-block; background-color: #F1F0F0; padding: 10px; border-radius: 10px; max-width: 60%;">
                        {msg}
                    </div>
                </div>
                '''
    return formatted_history

def clear_status():
    upload_status.update("")
    delete_status.update("")
    vectorize_status.update("")
    delete_collection_status.update("")

def handle_knowledge_base_selection(selected_knowledge_base):
    if selected_knowledge_base == "Create knowledge base":
        return gr.update(visible=True, interactive=True), gr.update(choices=[], value=[]), gr.update(visible=False)
    elif selected_knowledge_base == "Model only":
        return gr.update(visible=False, interactive=False), gr.update(choices=[], value=[]), gr.update(visible=False)
    else:
        return gr.update(visible=False, interactive=False), search_knowledge_base(selected_knowledge_base), gr.update(visible=True)

def update_knowledge_base_dropdown():
    global knowledge_base_files
    choices = ["Create knowledge base"] + list(knowledge_base_files.keys())
    return gr.update(choices=choices)

def update_chat_knowledge_base_dropdown():
    global knowledge_base_files
    choices = ["Model only"] + list(knowledge_base_files.keys())
    return gr.update(choices=choices)


# SearxNG search helper
def search_searxng(query):
    searxng_url = 'http://localhost:8080/search'  # replace with your SearxNG instance URL
    params = {
        'q': query,
        'format': 'json'
    }
    response = requests.get(searxng_url, params=params)
    response.raise_for_status()
    return response.json()


# Ollama summarization helper
def summarize_with_ollama(model_dropdown, text, question):
    prompt = """
    Answer the user's question based on the content below.
    Content: '{0}'
    Question: {1}
    """.format(text, question)
    ollama_url = 'http://localhost:11434/api/generate'  # replace with your Ollama instance URL
    data = {
        'model': model_dropdown,
        "prompt": prompt,
        "stream": False
    }
    response = requests.post(ollama_url, json=data)
    response.raise_for_status()
    return response.json()


# Search handler
def ai_web_search(model_dropdown, user_query):
    # Search with SearxNG
    search_results = search_searxng(user_query)
    search_texts = [result['title'] + "\n" + result['content'] for result in search_results['results']]
    combined_text = "\n\n".join(search_texts)

    # Summarize with Ollama
    summary = summarize_with_ollama(model_dropdown, combined_text, user_query)
    # print(summary)
    # Return the result
    return summary['response']

# Alternative handler for the AI web search, kept commented out
# def ai_web_search(model_dropdown, query):
#     try:
#         # Add the actual web-search and AI logic here
#         # This is only an example; implement it to fit your setup
#         search_result = f"Search results: {query}"
#         ai_response = f"AI answer: based on the search results, the answer to '{query}' is..."
#         return f"{search_result}\n\n{ai_response}"
#     except Exception as e:
#         logger.error(f"Error in AI web search: {str(e)}")
#         return f"<div style='color: red;'>Error: {str(e)}</div>"

# Build the Gradio interface
with gr.Blocks() as demo:
    with gr.Column():
        # Title
        title = gr.HTML("<h1 style='text-align: center; font-size: 32px; font-weight: bold;'>Refined RAG System</h1>")
        # Announcement banner
        announcement = gr.HTML("<div style='text-align: center; font-size: 18px; color: red;'>Announcements: Refined RAG System, a Retrieval-Augmented Generation system!<br/>莫大大</div>")

        with gr.Tabs():
            with gr.TabItem("Knowledge Base"):
                knowledge_base_dropdown = gr.Dropdown(choices=["Create knowledge base"] + list(knowledge_base_files.keys()),
                                                      label="Select knowledge base")
                new_kb_input = gr.Textbox(label="Enter a new knowledge base name", visible=False, interactive=True)
                file_input = gr.Files(label="Upload files")
                upload_btn = gr.Button("Upload")
                file_list = gr.CheckboxGroup(label="Uploaded Files")
                delete_btn = gr.Button("Delete Selected Files")
                with gr.Row():
                    chunk_size_dropdown = gr.Dropdown(choices=[50, 100, 200, 300, 500, 700], label="chunk_size", value=200)
                    chunk_overlap_dropdown = gr.Dropdown(choices=[20, 50, 100, 200], label="chunk_overlap", value=50)
                vectorize_btn = gr.Button("Vectorize Selected Files")
                delete_collection_btn = gr.Button("Delete Collection")
                upload_status = gr.HTML()
                delete_status = gr.HTML()
                vectorize_status = gr.HTML()
                delete_collection_status = gr.HTML()

            with gr.TabItem("Chat"):
                with gr.Row():
                    model_dropdown = gr.Dropdown(choices=get_llm(), label="Model")
                    vector_dropdown = gr.Dropdown(choices=get_embeding_model(), label="Embedding")
                    chat_knowledge_base_dropdown = gr.Dropdown(choices=["Model only"] + vectordb.get_all_collections_name(), label="Knowledge base")
                    chain_dropdown = gr.Dropdown(choices=["Complex retrieval", "Simple retrieval", "rerank"], label="Chain mode", visible=False)
                chat_display = gr.HTML(label="Chat History")
                chat_input = gr.Textbox(label="Type a message")
                chat_btn = gr.Button("Send")
                clear_btn = gr.Button("Clear Chat History")

            with gr.TabItem("AI Web Search"):
                with gr.Row():
                    web_search_model_dropdown = gr.Dropdown(choices=get_llm(), label="Model")
                web_search_output = gr.Textbox(label="Search results and AI answer", lines=10)
                web_search_input = gr.Textbox(label="Enter a search query")

                web_search_btn = gr.Button("Search")

    def handle_upload(files):
        upload_result, new_files, status = upload_files(files)
        threading.Thread(target=clear_status).start()
        return upload_result, new_files, status, update_chat_knowledge_base_dropdown()

    def handle_delete(selected_knowledge_base, selected_files):
        tmp = []
        cols_files_tmp = vectordb.get_collcetion_content_files(c_name=selected_knowledge_base)
        for i in selected_files:
            if i in cols_files_tmp:
                tmp.append(i)
        del cols_files_tmp
        if tmp:
            vectordb.del_files(tmp, c_name=selected_knowledge_base)
            del tmp
        delete_result, status = delete_files(selected_files)
        threading.Thread(target=clear_status).start()
        return delete_result, status, update_chat_knowledge_base_dropdown()

    def handle_vectorize(selected_files, selected_knowledge_base, new_kb_name, chunk_size, chunk_overlap):
        vectorize_result, status = asyncio.run(async_vectorize_files(selected_files, selected_knowledge_base, new_kb_name, chunk_size, chunk_overlap))
        threading.Thread(target=clear_status).start()
        return vectorize_result, status, update_knowledge_base_dropdown(), update_chat_knowledge_base_dropdown()

    def handle_delete_collection(selected_knowledge_base):
        result, status = delete_collection(selected_knowledge_base)
        threading.Thread(target=clear_status).start()
        return result, status, update_chat_knowledge_base_dropdown()

    knowledge_base_dropdown.change(
        handle_knowledge_base_selection,
        inputs=knowledge_base_dropdown,
        outputs=[new_kb_input, file_list, chain_dropdown]
    )
    upload_btn.click(handle_upload, inputs=file_input, outputs=[file_list, file_list, upload_status, chat_knowledge_base_dropdown])
    delete_btn.click(handle_delete, inputs=[knowledge_base_dropdown, file_list], outputs=[file_list, delete_status, chat_knowledge_base_dropdown])
    vectorize_btn.click(handle_vectorize, inputs=[file_list, knowledge_base_dropdown, new_kb_input, chunk_size_dropdown, chunk_overlap_dropdown],
                        outputs=[gr.Textbox(visible=False), vectorize_status, knowledge_base_dropdown, chat_knowledge_base_dropdown])
    delete_collection_btn.click(handle_delete_collection, inputs=knowledge_base_dropdown,
                                outputs=[knowledge_base_dropdown, delete_collection_status, chat_knowledge_base_dropdown])

    chat_btn.click(chat_response, inputs=[model_dropdown, vector_dropdown, chat_knowledge_base_dropdown, chain_dropdown, chat_input], outputs=[chat_display, chat_input])
    clear_btn.click(clear_chat, outputs=chat_display)

    chat_knowledge_base_dropdown.change(
        fn=lambda selected: gr.update(visible=selected != "Model only"),
        inputs=chat_knowledge_base_dropdown,
        outputs=chain_dropdown
    )

    # Click handler for the web-search tab
    web_search_btn.click(
        ai_web_search,
        inputs=[web_search_model_dropdown, web_search_input],
        outputs=web_search_output
    )

demo.launch(debug=True, share=True)
embeding/__pycache__/asr_utils.cpython-310.pyc
ADDED
Binary file (634 Bytes)

embeding/__pycache__/chromadb.cpython-310.pyc
ADDED
Binary file (3.91 kB)

embeding/__pycache__/elasticsearchStore.cpython-310.pyc
ADDED
Binary file (4.18 kB)

embeding/__pycache__/faissdb.cpython-310.pyc
ADDED
Binary file (4.21 kB)
embeding/asr_utils.py
ADDED
@@ -0,0 +1,17 @@

#coding:utf-8
from funasr import AutoModel
# paraformer-zh is a multi-functional asr model
# use vad, punc, spk or not as you need
model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc",
                  # spk_model="cam++"
                  )
def get_spk_txt(file):
    res = model.generate(input=file,
                         batch_size_s=300,
                         hotword='魔搭')
    print(res[0]["text"])
    fw = "embeding/tmp.txt"
    f = open(fw, "w", encoding="utf-8")
    f.write('"context"\n' + res[0]["text"])
    f.close()
    return fw
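A short usage sketch (assuming FunASR and the paraformer-zh weights are available locally; "demo.wav" is a hypothetical input file):

from embeding.asr_utils import get_spk_txt

# Transcribe an audio file; the transcript is written to embeding/tmp.txt
# as a one-column "context" CSV and the path to that file is returned.
csv_path = get_spk_txt("demo.wav")  # hypothetical input
print("transcript written to", csv_path)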
embeding/chromadb.py
ADDED
@@ -0,0 +1,134 @@

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from .asr_utils import get_spk_txt

class ChromaDB():
    def __init__(self, embedding="mofanke/acge_text_embedding:latest", persist_directory="./Chroma_db/"):

        self.embedding = OllamaEmbeddings(model=embedding)
        self.persist_directory = persist_directory
        self.chromadb = Chroma(persist_directory=persist_directory)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=50)

    def parse_data(self, file):
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except:
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
            # Transcribe the audio to text first
            fw = get_spk_txt(file)
            loaders = UnstructuredCSVLoader(fw)
            data = loaders.load()
            tmp = []
            for i in data:
                i.metadata["source"] = file
                tmp.append(i)
            data = tmp
        return data

    # Create a new collection and initialize it
    def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start building the database ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = self.chromadb.from_documents(documents=splits, collection_name=c_name,
                                                   embedding=self.embedding, persist_directory=self.persist_directory)
        print("Total number of chunks:", vectorstore._collection.count())

        return vectorstore

    # Add data to an existing collection
    def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start adding files...")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = Chroma(persist_directory=self.persist_directory, collection_name=c_name,
                             embedding_function=self.embedding)
        vectorstore.add_documents(splits)
        print("Total number of chunks:", vectorstore._collection.count())

        return vectorstore

    # Delete specific files from a collection
    def del_files(self, del_files_name, c_name):

        vectorstore = self.chromadb._client.get_collection(c_name)
        del_ids = []
        vec_dict = vectorstore.get()
        for id, md in zip(vec_dict["ids"], vec_dict["metadatas"]):
            for dl in del_files_name:
                if dl in md["source"]:
                    del_ids.append(id)
        vectorstore.delete(ids=del_ids)
        print("Total number of chunks:", vectorstore.count())

        return vectorstore

    # Delete a knowledge-base collection
    def delete_collection(self, c_name):

        self.chromadb._client.delete_collection(c_name)

    # List all current collections
    def get_all_collections_name(self):
        cl_names = []

        test = self.chromadb._client.list_collections()
        for i in range(len(test)):
            cl_names.append(test[i].name)
        return cl_names

    # List all files contained in a collection
    def get_collcetion_content_files(self, c_name):
        vectorstore = self.chromadb._client.get_collection(c_name)
        c_files = []
        vec_dict = vectorstore.get()
        for md in vec_dict["metadatas"]:
            c_files.append(md["source"])
        return list(set(c_files))


# if __name__ == "__main__":
#     chromadb = ChromaDB()
#     c_name = "sss3"
#
#     print(chromadb.get_all_collections_name())
#     chromadb.create_collection(["data/肾内科学.txt", "data/jl.pdf"], c_name=c_name)
#     print(chromadb.get_all_collections_name())
#     chromadb.add_chroma(["data/儿科学.txt"], c_name=c_name)
#     print(c_name, "contains files:", chromadb.get_collcetion_content_files(c_name))
#     chromadb.del_files(["data/肾内科学.txt"], c_name=c_name)
#     print(c_name, "contains files:", chromadb.get_collcetion_content_files(c_name))
#     print(chromadb.get_all_collections_name())
#     chromadb.delete_collection(c_name=c_name)
#     print(chromadb.get_all_collections_name())
embeding/elasticsearchStore.py
ADDED
@@ -0,0 +1,147 @@

from elasticsearch import Elasticsearch
from langchain_elasticsearch.vectorstores import ElasticsearchStore
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, \
    UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from .asr_utils import get_spk_txt
import requests


class ElsStore():
    def __init__(self, embedding="mofanke/acge_text_embedding:latest", es_url="http://localhost:9200",
                 index_name='test_index'):
        self.embedding = OllamaEmbeddings(model=embedding)
        self.es_url = es_url
        self.elastic_vector_search = ElasticsearchStore(
            es_url=self.es_url,
            index_name=index_name,
            embedding=self.embedding
        )

    def parse_data(self, file):
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except:
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
            # Transcribe the audio to text first
            fw = get_spk_txt(file)
            loaders = UnstructuredCSVLoader(fw)
            data = loaders.load()
            tmp = []
            for i in data:
                i.metadata["source"] = file
                tmp.append(i)
            data = tmp
        return data

    def get_count(self, c_name):
        # Get the number of chunks stored in the index

        # Initialize the Elasticsearch client
        es = Elasticsearch([{
            'host': self.es_url.split(":")[1][2:],
            'port': int(self.es_url.split(":")[2]),
            'scheme': 'http'  # protocol to use
        }])

        # Target index name
        index_name = c_name

        # Get the total document count
        response = es.count(index=index_name)

        # Return the document count
        return response['count']

    # Create a new index and initialize it
    def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start building the database ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        self.elastic_vector_search = ElasticsearchStore.from_documents(
            documents=splits,
            embedding=self.embedding,
            es_url=self.es_url,
            index_name=c_name,
        )

        self.elastic_vector_search.client.indices.refresh(index=c_name)

        print("Total number of chunks:", self.get_count(c_name))

        return self.elastic_vector_search

    # Add data to an existing index
    def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start adding files...")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        self.elastic_vector_search = ElasticsearchStore(
            es_url=self.es_url,
            index_name=c_name,
            embedding=self.embedding
        )
        self.elastic_vector_search.add_documents(splits)
        self.elastic_vector_search.client.indices.refresh(index=c_name)
        print("Total number of chunks:", self.get_count(c_name))

        return self.elastic_vector_search

    # Delete a knowledge-base index
    def delete_collection(self, c_name):
        url = self.es_url + "/" + c_name
        # Send a DELETE request
        response = requests.delete(url)

        # Check the response status code
        if response.status_code == 200:
            return f"Index '{c_name}' deleted successfully."
        elif response.status_code == 404:
            return f"Index '{c_name}' does not exist."
        else:
            return f"Error deleting index: {response.status_code}, {response.text}"

    # List all current index names
    def get_all_collections_name(self):
        indices = self.elastic_vector_search.client.indices.get_alias()
        index_names = list(indices.keys())

        return index_names

    def get_collcetion_content_files(self, c_name):
        return []

    # Delete specific files from a collection (not implemented for the ES back end)
    def del_files(self, del_files_name, c_name):
        return None
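Unlike the Chroma and FAISS back ends, this file ships no usage example; a minimal sketch (assuming Elasticsearch on localhost:9200, Ollama serving the default embedding model, and a hypothetical data/sample.txt) might look like:

from embeding.elasticsearchStore import ElsStore

store = ElsStore()  # defaults: local ES at :9200, acge_text_embedding via Ollama
store.create_collection(["data/sample.txt"], c_name="demo_index")  # hypothetical file
print(store.get_all_collections_name())       # lists ES indices, including demo_index
print(store.get_count("demo_index"))          # number of stored chunks
print(store.delete_collection("demo_index"))  # status message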
embeding/faissdb.py
ADDED
@@ -0,0 +1,138 @@

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import shutil
import os
from .asr_utils import get_spk_txt


class FaissDB():
    def __init__(self, embedding="mofanke/acge_text_embedding:latest", persist_directory="./Faiss_db/"):

        self.embedding = OllamaEmbeddings(model=embedding)
        self.persist_directory = persist_directory
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50, add_start_index=True)

    def parse_data(self, file):
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except:
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
            # Transcribe the audio to text first
            fw = get_spk_txt(file)
            loaders = UnstructuredCSVLoader(fw)
            data = loaders.load()
            tmp = []
            for i in data:
                i.metadata["source"] = file
                tmp.append(i)
            data = tmp
        return data

    # Create a new collection and initialize it
    def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start building the database ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = FAISS.from_documents(documents=splits,
                                           embedding=self.embedding)
        vectorstore.save_local(self.persist_directory + c_name)
        print("Total number of chunks:", vectorstore.index.ntotal)

        return vectorstore

    # Add data to an existing collection
    def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start adding files...")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
                                       allow_dangerous_deserialization=True)
        vectorstore.add_documents(documents=splits)
        vectorstore.save_local(self.persist_directory + c_name)
        print("Total number of chunks:", vectorstore.index.ntotal)

        return vectorstore

    # Delete specific files from a collection
    def del_files(self, del_files_name, c_name):

        vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
                                       allow_dangerous_deserialization=True)
        del_ids = []
        vec_dict = vectorstore.docstore._dict
        for id, md in vec_dict.items():
            for dl in del_files_name:
                if dl in md.metadata["source"]:
                    del_ids.append(id)
        vectorstore.delete(ids=del_ids)
        vectorstore.save_local(self.persist_directory + c_name)
        print("Total number of chunks:", vectorstore.index.ntotal)

        return vectorstore

    # Delete a knowledge-base collection
    def delete_collection(self, c_name):
        shutil.rmtree(self.persist_directory + c_name)

    # List all current collections
    def get_all_collections_name(self):
        cl_names = [i for i in os.listdir(self.persist_directory) if os.path.isdir(self.persist_directory + i)]

        return cl_names

    # List all files contained in a collection
    def get_collcetion_content_files(self, c_name):
        vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
                                       allow_dangerous_deserialization=True)
        c_files = []
        vec_dict = vectorstore.docstore._dict
        for _, md in vec_dict.items():
            c_files.append(md.metadata["source"])

        return list(set(c_files))


# if __name__ == "__main__":
#     chromadb = FaissDB()
#     c_name = "sss3"
#
#     print(chromadb.get_all_collections_name())
#     chromadb.create_collection(["data/jl.txt", "data/jl.pdf"], c_name=c_name)
#     print(chromadb.get_all_collections_name())
#     chromadb.add_chroma(["data/tmp.txt"], c_name=c_name)
#     print(c_name, "contains files:", chromadb.get_collcetion_content_files(c_name))
#     chromadb.del_files(["data/tmp.txt"], c_name=c_name)
#     print(c_name, "contains files:", chromadb.get_collcetion_content_files(c_name))
#     print(chromadb.get_all_collections_name())
#     chromadb.delete_collection(c_name=c_name)
#     print(chromadb.get_all_collections_name())
embeding/tmp.txt
ADDED
@@ -0,0 +1,2 @@

"context"
你是不是觉得自己说话的声音直来直去呢?现在告诉你一个主持人吐字的小秘密,那就是每个字在口腔当中像是翻跟头一样打一圈再出来。比如说故人西辞黄鹤楼,而不是故人西辞黄鹤楼。再比如说乌衣巷口夕阳斜,而不是乌衣巷口夕阳斜,你也试试看抖音。
graph_demo_ui.py
ADDED
@@ -0,0 +1,87 @@
+# -*- coding: utf-8 -*-
+from flask import Flask, render_template, request, jsonify
+import json
+from dotenv import load_dotenv
+from langchain_community.llms import Ollama
+
+
+load_dotenv()
+
+app = Flask(__name__)
+
+# Tested llama3:8b, gemma2:9b, qwen2:7b, glm4:9b and arcee-ai/arcee-agent:latest; so far qwen2:7b works best.
+llm = Ollama(model="qwen2:7b")
+
+
+json_example = {'edges': [{'data': {'color': '#FFA07A',
+                                    'label': 'label 1',
+                                    'source': 'source 1',
+                                    'target': 'target 1'}},
+                          {'data': {'color': '#FFA07A',
+                                    'label': 'label 2',
+                                    'source': 'source 2',
+                                    'target': 'target 2'}}],
+                'nodes': [{'data': {'color': '#FFC0CB', 'id': 'id 1', 'label': 'label 1'}},
+                          {'data': {'color': '#90EE90', 'id': 'id 2', 'label': 'label 2'}},
+                          {'data': {'color': '#87CEEB', 'id': 'id 3', 'label': 'label 3'}}]}
+
+
+# Extraction prompt: the model must reply with bare Cytoscape-ready JSON (nodes with
+# unique "id_<n>" ids, edges whose source/target reference existing node ids, and a
+# "color" attribute), with no markdown fence around it.
+__retriever_prompt = f"""
+您是一名专门从事知识图谱创建的人工智能专家,目标是根据给定的输入或请求捕获关系。
+基于各种形式的用户输入,如段落、电子邮件、文本文件等。
+你的任务是根据输入创建一个知识图谱。
+nodes必须具有label参数,并且label是来自输入的词语或短语,nodes必须具有id参数,id的格式是"id_数字",不能重复。
+edges还必须有一个label参数,其中label是输入中的直接词语或短语,edges中的source和target取自nodes中的id。
+仅使用JSON进行响应,其格式可以在python中进行jsonify,并直接输入cy.add(data),包括“color”属性,以在前端显示图形。
+您可以参考给定的示例:{json_example}。存储node和edge的数组中,最后一个元素后边不要有逗号,
+确保边的目标和源与现有节点匹配。
+不要在JSON的上方和下方包含markdown三引号,直接用花括号括起来。
+"""
+
+
+def generate_graph_info(raw_text: str) -> str:
+    """
+    Generate graph info (a JSON string) from raw text.
+    :param raw_text: input text to extract a knowledge graph from
+    :return: the model's JSON response
+    """
+    messages = [
+        {"role": "system", "content": "你现在扮演信息抽取的角色,要求根据用户输入和AI的回答,正确提取出信息,记得不要对实体进行翻译。"},
+        {"role": "user", "content": raw_text},
+        {"role": "user", "content": __retriever_prompt}
+    ]
+    print("Parsing....")
+    # Retry up to 3 times; a very short response is treated as a failed extraction.
+    for i in range(3):
+        graph_info_result = llm.invoke(messages)
+        if len(graph_info_result) < 10:
+            print("-------", i, "-------------------")
+            continue
+        else:
+            break
+    print(graph_info_result)
+    return graph_info_result
+
+
+@app.route('/')
+def index():
+    return render_template('index.html')
+
+
+@app.route('/update_graph', methods=['POST'])
+def update_graph():
+    raw_text = request.json.get('text', '')
+    try:
+        result = generate_graph_info(raw_text)
+        if '```' in result:
+            # Strip the markdown fence the model sometimes adds despite the prompt;
+            # drop only a leading "json" language tag, not "json" inside the payload.
+            payload = result.split('```', 2)[1]
+            if payload.lstrip().startswith("json"):
+                payload = payload.lstrip()[len("json"):]
+            graph_data = json.loads(payload)
+        else:
+            graph_data = json.loads(result)
+        return graph_data
+    except Exception as e:
+        return {'error': f"Error parsing graph data: {str(e)}"}
+
+
+if __name__ == '__main__':
+    app.run(host='0.0.0.0', port=7860)
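For reference, a minimal client sketch for the /update_graph route above. It assumes the Flask app is running locally on port 7860; the sample sentence (adapted from the test text in test/graph2neo4j.py) and the printed fields are illustrative only:

import requests

resp = requests.post(
    "http://localhost:7860/update_graph",
    json={"text": "范冰冰1981年9月16日生于山东青岛。"},
)
graph = resp.json()
# On success the shape mirrors json_example: {'nodes': [...], 'edges': [...]}.
for node in graph.get("nodes", []):
    print(node["data"]["id"], node["data"]["label"])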
img/graph-tool.png
ADDED
Git LFS Details
img/readme.txt
ADDED
@@ -0,0 +1 @@
+1
img/zhu.png
ADDED
Git LFS Details
img/zhuye.png
ADDED
img/复杂方式.png
ADDED
img/微信图片_20240524180648.jpg
ADDED
rag/__init__.py
ADDED
File without changes
rag/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (132 Bytes). View file
rag/__pycache__/config.cpython-310.pyc
ADDED
Binary file (364 Bytes). View file
rag/__pycache__/rag_class.cpython-310.pyc
ADDED
Binary file (5.39 kB). View file
rag/__pycache__/rerank.cpython-310.pyc
ADDED
Binary file (878 Bytes). View file
rag/__pycache__/rerank.cpython-39.pyc
ADDED
Binary file (869 Bytes). View file
rag/__pycache__/rerank_code.cpython-310.pyc
ADDED
Binary file (883 Bytes). View file
rag/rag_class.py
ADDED
@@ -0,0 +1,169 @@
+from langchain_community.vectorstores import Chroma, FAISS
+from langchain_community.llms import Ollama
+from langchain_core.output_parsers import StrOutputParser
+from langchain_community.embeddings import OllamaEmbeddings
+from langchain_core.runnables import RunnablePassthrough
+from operator import itemgetter
+from langchain.prompts import ChatPromptTemplate
+from rag.rerank_code import rerank_topn
+from Config.config import VECTOR_DB, DB_directory
+from langchain_elasticsearch.vectorstores import ElasticsearchStore
+
+
+class RAG_class:
+    def __init__(self, model="qwen2:7b", embed="milkey/dmeta-embedding-zh:f16", c_name="sss1",
+                 persist_directory="E:/pycode/jupyter_code/langGraph/sss2/chroma.sqlite3/", es_url="http://localhost:9200"):
+        # Answer strictly from the retrieved context.
+        template = """
+        根据上下文回答以下问题,不要自己发挥,要根据以下参考内容总结答案,如果以下内容无法得到答案,就返回无法根据参考内容获取答案,
+
+        参考内容为:{context}
+
+        问题: {question}
+        """
+        self.prompts = ChatPromptTemplate.from_template(template)
+
+        # Question expansion + recursive answering to reach the final answer.
+        template1 = """你是一个乐于助人的助手,可以生成与输入问题相关的多个子问题。
+        目标是将输入分解为一组可以单独回答的子问题/子问题。
+        生成多个与以下内容相关的搜索查询:{question}
+        输出4个相关问题,以换行符隔开:"""
+        self.prompt_questions = ChatPromptTemplate.from_template(template1)
+
+        # Answer a question given accumulated background Q&A pairs plus fresh context.
+        template2 = """
+        以下是您需要回答的问题:
+
+        \n---\n {question} \n---\n
+
+        以下是任何可用的背景问答对:
+
+        \n---\n {q_a_pairs} \n---\n
+
+        以下是与该问题相关的其他上下文:
+
+        \n---\n {context} \n---\n
+
+        使用以上上下文和背景问答对来回答问题,问题是:{question} ,答案是:
+        """
+        self.decomposition_prompt = ChatPromptTemplate.from_template(template2)
+
+        self.llm = Ollama(model=model)
+        self.embeding = OllamaEmbeddings(model=embed)
+        # Load the configured vector store once; when only the LLM is needed the
+        # store may be absent, so failures are tolerated.
+        try:
+            if VECTOR_DB == 1:
+                self.vectstore = Chroma(embedding_function=self.embeding, collection_name=c_name,
+                                        persist_directory=persist_directory)
+            elif VECTOR_DB == 2:
+                self.vectstore = FAISS.load_local(folder_path=persist_directory + c_name, embeddings=self.embeding,
+                                                  allow_dangerous_deserialization=True)
+            elif VECTOR_DB == 3:
+                self.vectstore = ElasticsearchStore(
+                    es_url=es_url,
+                    index_name=c_name,
+                    embedding=self.embeding
+                )
+            self.retriever = self.vectstore.as_retriever()
+        except Exception as e:
+            print("No vector store loaded (model-only mode):", e)
+
+    # Post-processing
+    def format_docs(self, docs):
+        return "\n\n".join(doc.page_content for doc in docs)
+
+    # Plain retrieval: recall with the single question, then let the LLM summarize.
+    def simple_chain(self, question):
+        _chain = (
+            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
+            | self.prompts
+            | self.llm
+            | StrOutputParser()
+        )
+        # The chain takes the bare question string, so both the retriever and the
+        # passthrough receive it directly.
+        answer = _chain.invoke(question)
+        return answer
+
+    # Recall top-10, rerank down to top-5, then answer from those documents.
+    def rerank_chain(self, question):
+        retriever = self.vectstore.as_retriever(search_kwargs={"k": 10})
+        docs = retriever.invoke(question)
+        docs = rerank_topn(question, docs, N=5)
+        _chain = (
+            self.prompts
+            | self.llm
+            | StrOutputParser()
+        )
+        answer = _chain.invoke({"context": self.format_docs(docs), "question": question})
+        return answer
+
+    def format_qa_pairs(self, question, answer):
+        return f"Question: {question}\nAnswer:{answer}\n\n"
+
+    # Expand the input question into related sub-questions.
+    def decomposition_chain(self, question):
+        _chain = (
+            {"question": RunnablePassthrough()}
+            | self.prompt_questions
+            | self.llm
+            | StrOutputParser()
+            | (lambda x: x.split("\n"))
+        )
+        questions = _chain.invoke(question) + [question]
+        return questions
+
+    # Recursive multi-question recall: each round's question and answer are fed into
+    # the next round as background Q&A pairs before recalling with the next question.
+    def rag_chain(self, questions):
+        q_a_pairs = ""
+        for q in questions:
+            _chain = (
+                {"context": itemgetter("question") | self.retriever,
+                 "question": itemgetter("question"),
+                 "q_a_pairs": itemgetter("q_a_pairs")
+                 }
+                | self.decomposition_prompt
+                | self.llm
+                | StrOutputParser()
+            )
+            answer = _chain.invoke({"question": q, "q_a_pairs": q_a_pairs})
+            # Accumulate the new pair instead of overwriting the history.
+            q_a_pairs = q_a_pairs + "\n---\n" + self.format_qa_pairs(q, answer)
+        return answer
+
+    # Format chat history into a single string.
+    def format_chat_history(self, history):
+        formatted_history = ""
+        for role, content in history:
+            formatted_history += f"{role}: {content}\n"
+        return formatted_history
+
+    # Multi-turn chat against the Ollama model, without the knowledge base.
+    def mult_chat(self, chat_history):
+        formatted_history = self.format_chat_history(chat_history)
+        response = self.llm.invoke(formatted_history)
+        return response
+
+
+# if __name__ == "__main__":
+#     rag = RAG_class(model="deepseek-r1:14b")
+#     question = "人卫社官网网址是?"
+#     questions = rag.decomposition_chain(question)
+#     print(questions)
+#     answer = rag.rag_chain(questions)
+#     print(answer)
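A short usage sketch for RAG_class. It assumes Ollama is serving qwen2:7b and that the FAISS collection "sss1" shipped in this commit exists under DB_directory (./Faiss_db/); the question is the one from the commented-out test above:

from Config.config import DB_directory
from rag.rag_class import RAG_class

rag = RAG_class(c_name="sss1", persist_directory=DB_directory)
# Single-question recall, then LLM summarization.
print(rag.simple_chain("人卫社官网网址是?"))
# Top-10 recall reranked down to top-5 before answering.
print(rag.rerank_chain("人卫社官网网址是?"))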
rag/rerank_code.py
ADDED
@@ -0,0 +1,21 @@
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+# Local path to the bge-reranker-large cross-encoder weights.
+tokenizer = AutoTokenizer.from_pretrained('E:\\model\\bge-reranker-large')
+model = AutoModelForSequenceClassification.from_pretrained('E:\\model\\bge-reranker-large')
+model.eval()
+
+
+# Score each (question, document) pair with the cross-encoder and return the
+# N highest-scoring documents, best first.
+def rerank_topn(question, docs, N=5):
+    pairs = [[question, doc.page_content] for doc in docs]
+    with torch.no_grad():
+        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
+        scores = model(**inputs, return_dict=True).logits.view(-1).float()
+    top_idx = scores.argsort().numpy()[::-1][:N]
+    return [docs[i] for i in top_idx]
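A standalone sketch of calling rerank_topn. The Document objects are built by hand here purely for illustration; in rag_class.py the documents come from the vector store retriever:

from langchain_core.documents import Document
from rag.rerank_code import rerank_topn

docs = [Document(page_content=t) for t in ["人卫社是一家出版社。", "今天天气很好。"]]
best = rerank_topn("人卫社是什么?", docs, N=1)
print(best[0].page_content)  # the more relevant of the two documents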
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+gradio==4.29.0
+langchain-community==0.2.6
+langchain==0.2.6
+langchain-core==0.2.11
+requests
+transformers==4.41.1
+unstructured==0.7.12
+funasr==1.0.24
+modelscope
+chromadb
test/__init__.py
ADDED
File without changes
test/graph2neo4j.py
ADDED
@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+import sys
+sys.path.append("..")  # make the project packages importable when run from test/
+from Neo4j.neo4j_op import KnowledgeGraph
+from Neo4j.graph_extract import update_graph
+from Config.config import neo4j_host, neo4j_name, neo4j_pwd
+
+
+kg = KnowledgeGraph(neo4j_host, neo4j_name, neo4j_pwd)
+
+
+if __name__ == "__main__":
+    text = """范冰冰,1981年9月16日生于山东青岛,毕业于上海师范大学谢晋影视艺术学院,中国女演员,歌手。
+1998年参演电视剧《还珠格格》成名。2004年主演电影《手机》获得第27届大众电影百花奖最佳女演员奖。"""
+    res = update_graph(text)
+    # Create the extracted nodes in batch.
+    nodes = kg.create_nodes("node", res["nodes"])
+    print(nodes)
+    # Create the extracted relationships in batch.
+    relationships = kg.create_relationships([
+        ("node", {"name": edge["source"]}, "node", {"name": edge["target"]}, edge["label"]) for edge in res["edges"]
+    ])
+    print(relationships)
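For orientation, the structure update_graph is expected to return, inferred from how res is consumed above (node dicts carry a "name" key that the relationship tuples match on); the concrete values are illustrative only:

# Illustrative shape only; actual nodes and edges depend on the LLM extraction.
res = {
    "nodes": [{"name": "范冰冰"}, {"name": "山东青岛"}],
    "edges": [{"source": "范冰冰", "target": "山东青岛", "label": "生于"}],
}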