chengyingmo commited on
Commit
83f7ed7
·
verified ·
1 Parent(s): 763356c

Upload 41 files

Browse files
Files changed (42) hide show
  1. .gitattributes +2 -0
  2. Chroma_db/readme.txt +0 -0
  3. Config/__pycache__/config.cpython-310.pyc +0 -0
  4. Config/config.py +14 -0
  5. Faiss_db/readme.txt +0 -0
  6. Faiss_db/sss1/index.faiss +0 -0
  7. Faiss_db/sss1/index.pkl +3 -0
  8. Neo4j/__pycache__/graph_extract.cpython-310.pyc +0 -0
  9. Neo4j/__pycache__/neo4j_op.cpython-310.pyc +0 -0
  10. Neo4j/graph_extract.py +69 -0
  11. Neo4j/neo4j_op.py +105 -0
  12. Ollama_api/__pycache__/ollama_api.cpython-310.pyc +0 -0
  13. Ollama_api/ollama_api.py +21 -0
  14. app.py +354 -0
  15. embeding/__pycache__/asr_utils.cpython-310.pyc +0 -0
  16. embeding/__pycache__/chromadb.cpython-310.pyc +0 -0
  17. embeding/__pycache__/elasticsearchStore.cpython-310.pyc +0 -0
  18. embeding/__pycache__/faissdb.cpython-310.pyc +0 -0
  19. embeding/asr_utils.py +17 -0
  20. embeding/chromadb.py +134 -0
  21. embeding/elasticsearchStore.py +147 -0
  22. embeding/faissdb.py +138 -0
  23. embeding/tmp.txt +2 -0
  24. graph_demo_ui.py +87 -0
  25. img/graph-tool.png +3 -0
  26. img/readme.txt +1 -0
  27. img/zhu.png +3 -0
  28. img/zhuye.png +0 -0
  29. img//345/244/215/346/235/202/346/226/271/345/274/217.png +0 -0
  30. img//345/276/256/344/277/241/345/233/276/347/211/207_20240524180648.jpg +0 -0
  31. rag/__init__.py +0 -0
  32. rag/__pycache__/__init__.cpython-310.pyc +0 -0
  33. rag/__pycache__/config.cpython-310.pyc +0 -0
  34. rag/__pycache__/rag_class.cpython-310.pyc +0 -0
  35. rag/__pycache__/rerank.cpython-310.pyc +0 -0
  36. rag/__pycache__/rerank.cpython-39.pyc +0 -0
  37. rag/__pycache__/rerank_code.cpython-310.pyc +0 -0
  38. rag/rag_class.py +169 -0
  39. rag/rerank_code.py +21 -0
  40. requirements.txt +10 -0
  41. test/__init__.py +0 -0
  42. test/graph2neo4j.py +25 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ img/graph-tool.png filter=lfs diff=lfs merge=lfs -text
37
+ img/zhu.png filter=lfs diff=lfs merge=lfs -text
Chroma_db/readme.txt ADDED
File without changes
Config/__pycache__/config.cpython-310.pyc ADDED
Binary file (362 Bytes). View file
 
Config/config.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Vector-store backend selection: 1 = Chroma, 2 = FAISS, 3 = ElasticsearchStore.
VECTOR_DB = 2

# Map each backend id to its storage location; unknown ids fall back to Chroma's dir.
_DB_DIRECTORIES = {2: "./Faiss_db/", 3: "es"}
DB_directory = _DB_DIRECTORIES.get(VECTOR_DB, "./Chroma_db/")

# Neo4j connection settings.
neo4j_host = "bolt://localhost:7687"
neo4j_name = "neo4j"
neo4j_pwd = "12345678"
# Models tried: llama3:8b, gemma2:9b, qwen2:7b, glm4:9b, arcee-ai/arcee-agent:latest;
# qwen2:7b has performed best so far.
neo4j_model = "qwen2:7b"
Faiss_db/readme.txt ADDED
File without changes
Faiss_db/sss1/index.faiss ADDED
Binary file (82 kB). View file
 
Faiss_db/sss1/index.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bb588f4bd46218f42b045c42163bdcf3cc76a19e37458823ceaeaf8a1454e3b
3
+ size 9362
Neo4j/__pycache__/graph_extract.cpython-310.pyc ADDED
Binary file (2.51 kB). View file
 
Neo4j/__pycache__/neo4j_op.cpython-310.pyc ADDED
Binary file (3.89 kB). View file
 
Neo4j/graph_extract.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import ast
import json

from langchain_community.llms import Ollama
from Config.config import neo4j_model

# Models tried: llama3:8b, gemma2:9b, qwen2:7b, glm4:9b, arcee-ai/arcee-agent:latest;
# qwen2:7b has performed best so far.
llm = Ollama(model=neo4j_model)

# Shape the model must produce: a cytoscape-style graph payload.
json_example = {'edges': [
    {
        'label': 'label 1',
        'source': 'source 1',
        'target': 'target 1'},
    {
        'label': 'label 1',
        'source': 'source 1',
        'target': 'target 1'}
],
    'nodes': [{'name': 'label 1'},
              {'name': 'label 2'},
              {'name': 'label 3'}]
}

__retriever_prompt = f"""
您是一名专门从事知识图谱创建的人工智能专家,目标是根据给定的输入或请求捕获关系。
基于各种形式的用户输入,如段落、电子邮件、文本文件等。
你的任务是根据输入创建一个知识图谱。
nodes中每个元素只有一个name参数,name对应的值是一个实体,实体来自输入的词语或短语。
edges还必须有一个label参数,其中label是输入中的直接词语或短语,edges中的source和target取自nodes中的name。

仅使用JSON进行响应,其格式可以在python中进行jsonify,并直接输入cy.add(data),
您可以参考给定的示例:{json_example}。存储node和edge的数组中,最后一个元素后边不要有逗号,
确保边的目标和源与现有节点匹配。
不要在JSON的上方和下方包含markdown三引号,直接用花括号括起来。
"""


def generate_graph_info(raw_text: str) -> str | None:
    """Ask the LLM to extract a knowledge graph from *raw_text*.

    Retries up to 3 times when the reply looks truncated (< 10 characters)
    and returns the raw model text for update_graph() to parse.

    :param raw_text: free-form user input (paragraph, e-mail, file content, ...)
    :return: the raw LLM response text
    """
    messages = [
        {"role": "system", "content": "你现在扮演信息抽取的角色,要求根据用户输入和AI的回答,正确提取出信息,记得不多对实体进行翻译。"},
        {"role": "user", "content": raw_text},
        {"role": "user", "content": __retriever_prompt}
    ]
    print("解析中....")
    for i in range(3):
        graph_info_result = llm.invoke(messages)
        if len(graph_info_result) < 10:
            # Suspiciously short answer — retry.
            print("-------", i, "-------------------")
            continue
        else:
            break
    print(graph_info_result)
    return graph_info_result


def _parse_graph_payload(text: str):
    """Parse model output into a Python object WITHOUT eval().

    eval() on LLM output is arbitrary code execution (security fix); json.loads
    covers well-formed JSON and ast.literal_eval covers Python-literal dicts.
    """
    try:
        return json.loads(text)
    except (TypeError, ValueError):
        return ast.literal_eval(text)


def update_graph(raw_text):
    """Extract a graph payload dict from *raw_text*.

    Strips an optional ``` fenced block around the JSON. On any failure returns
    {'error': <message>} instead of raising, matching the original contract.
    """
    try:
        result = generate_graph_info(raw_text)
        if '```' in result:
            payload = result.split('```', 2)[1].replace("json", '')
        else:
            payload = str(result)
        return _parse_graph_payload(payload)
    except Exception as e:
        return {'error': f"Error parsing graph data: {str(e)}"}
Neo4j/neo4j_op.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from py2neo import Graph, Node, Relationship
from langchain_community.document_loaders import TextLoader,UnstructuredCSVLoader, UnstructuredPDFLoader,UnstructuredWordDocumentLoader,UnstructuredExcelLoader,UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter


class KnowledgeGraph:
    """CRUD helper over a Neo4j graph plus document loading/splitting utilities."""

    def __init__(self, uri, user, password):
        self.graph = Graph(uri, auth=(user, password))

    def parse_data(self, file):
        """Load `file` into LangChain documents, choosing a loader by extension.

        Returns [] for unsupported extensions — the original left `data` unbound
        and raised UnboundLocalError in that case.
        """
        data = []
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except Exception:
                # Not valid CSV — fall back to plain UTF-8 text.
                # (was a bare `except:` that also swallowed KeyboardInterrupt)
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        return data

    # Split documents into overlapping chunks.
    def split_files(self, files, chunk_size=500, chunk_overlap=100):
        """Parse every file in `files` and split the documents into chunks."""
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("开始创建数据库 ....")
        tmps = []
        for file in files:
            tmps.extend(self.parse_data(file))

        splits = text_splitter.split_documents(tmps)

        return splits

    def create_node(self, label, properties):
        """Return the existing node matching label+properties, or create one."""
        matcher = self.graph.nodes.match(label, **properties)
        existing = matcher.first()
        if existing:
            return existing
        node = Node(label, **properties)
        self.graph.create(node)
        return node

    def create_relationship(self, label1, properties1, label2, properties2, relationship_type,
                            relationship_properties=None):
        """Idempotently create a relationship between two (possibly new) nodes.

        `relationship_properties` now defaults to None instead of a shared
        mutable `{}` (classic Python mutable-default pitfall); behavior for
        callers passing a dict is unchanged.
        """
        if relationship_properties is None:
            relationship_properties = {}
        node1 = self.create_node(label1, properties1)
        node2 = self.create_node(label2, properties2)

        # Reuse an existing relationship when every requested property matches.
        matcher = self.graph.match((node1, node2), r_type=relationship_type)
        for rel in matcher:
            if all(rel[key] == value for key, value in relationship_properties.items()):
                return rel

        relationship = Relationship(node1, relationship_type, node2, **relationship_properties)
        self.graph.create(relationship)
        return relationship

    def delete_node(self, label, properties):
        """Delete the first node matching label+properties; True if one was removed."""
        matcher = self.graph.nodes.match(label, **properties)
        node = matcher.first()
        if node:
            self.graph.delete(node)
            return True
        return False

    def update_node(self, label, identifier, updates):
        """Apply `updates` to the first node matching `identifier`; None if absent."""
        matcher = self.graph.nodes.match(label, **identifier)
        node = matcher.first()
        if node:
            for key, value in updates.items():
                node[key] = value
            self.graph.push(node)
            return node
        return None

    def find_node(self, label, properties):
        """All nodes matching label+properties."""
        matcher = self.graph.nodes.match(label, **properties)
        return list(matcher)

    def create_nodes(self, label, properties_list):
        """Create (or fetch) one node per properties dict, preserving order."""
        return [self.create_node(label, properties) for properties in properties_list]

    def create_relationships(self, relationships):
        """Create one relationship per (label1, props1, label2, props2, type) tuple."""
        created_relationships = []
        for rel in relationships:
            label1, properties1, label2, properties2, relationship_type = rel
            created_relationships.append(
                self.create_relationship(label1, properties1, label2, properties2, relationship_type))
        return created_relationships
Ollama_api/__pycache__/ollama_api.cpython-310.pyc ADDED
Binary file (721 Bytes). View file
 
Ollama_api/ollama_api.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import requests
import json

OLLAMA_TAGS_URL = "http://localhost:11434/api/tags"


def _list_model_names():
    """Return the names of all models the local Ollama server reports."""
    response = requests.get(url=OLLAMA_TAGS_URL)
    result = json.loads(response.content)
    return [m["name"] for m in result["models"]]


# API helpers listing the local Ollama models.
def get_llm():
    """Chat-capable models: everything that is not a code or embedding model."""
    return [n for n in _list_model_names() if "code" not in n and "embed" not in n]


def get_embeding_model():
    """Embedding models only (name misspelling kept — callers use it)."""
    return [n for n in _list_model_names() if "embed" in n]
app.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import threading
3
+ import asyncio
4
+ import logging
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from functools import lru_cache
7
+ import requests
8
+ import json
9
+
10
+ # 假设这些是您的自定义模块,需要根据实际情况进行调整
11
+ from Config.config import VECTOR_DB, DB_directory
12
+ from Ollama_api.ollama_api import *
13
+ from rag.rag_class import *
14
+
15
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Select the vector-database backend according to VECTOR_DB
# (1 = Chroma, 2 = FAISS, 3 = Elasticsearch) — see Config/config.py.
if VECTOR_DB == 1:
    from embeding.chromadb import ChromaDB as vectorDB
    vectordb = vectorDB(persist_directory=DB_directory)
elif VECTOR_DB == 2:
    from embeding.faissdb import FaissDB as vectorDB
    vectordb = vectorDB(persist_directory=DB_directory)
elif VECTOR_DB == 3:
    from embeding.elasticsearchStore import ElsStore as vectorDB
    vectordb = vectorDB()

# Names of files uploaded through the UI (appended by upload_files below).
uploaded_files = []
32
+
33
def get_knowledge_base_files():
    """Snapshot {collection name: [source files]} from the vector store.

    NOTE(review): the original decorated this with @lru_cache(maxsize=100);
    with zero arguments that pins the very first snapshot forever, so
    update_knowledge_base_files() could never see newly vectorized or deleted
    files. The cache is removed so each call re-queries the store.
    """
    cl_dict = {}
    for c_name in vectordb.get_all_collections_name():
        cl_dict[c_name] = vectordb.get_collcetion_content_files(c_name)
    return cl_dict

# Initial snapshot used to populate the UI dropdowns.
knowledge_base_files = get_knowledge_base_files()
42
+
43
def upload_files(files):
    """Record newly uploaded files; return (checkbox update, new names, status HTML)."""
    if not files:
        update_knowledge_base_files()
        return update_file_list(), [], "<div style='color: red; padding: 10px; border: 2px solid red; border-radius: 5px;'>Upload failed!</div>"
    new_files = [uploaded.name for uploaded in files]
    uploaded_files.extend(new_files)
    update_knowledge_base_files()
    logger.info(f"Uploaded files: {new_files}")
    return update_file_list(), new_files, "<div style='color: green; padding: 10px; border: 2px solid green; border-radius: 5px;'>Upload successful!</div>"
52
+
53
def delete_files(selected_files):
    """Drop the selected files from the upload list; return (list update, status HTML)."""
    global uploaded_files
    uploaded_files = [name for name in uploaded_files if name not in selected_files]
    if not selected_files:
        update_knowledge_base_files()
        return update_file_list(), "<div style='color: red; padding: 10px; border: 2px solid red; border-radius: 5px;'>Delete failed!</div>"
    update_knowledge_base_files()
    logger.info(f"Deleted files: {selected_files}")
    return update_file_list(), "<div style='color: green; padding: 10px; border: 2px solid green; border-radius: 5px;'>Delete successful!</div>"
62
+
63
def delete_collection(selected_knowledge_base):
    """Remove an entire collection unless it is the create-new placeholder."""
    deletable = bool(selected_knowledge_base) and selected_knowledge_base != "创建知识库"
    if deletable:
        vectordb.delete_collection(selected_knowledge_base)
        update_knowledge_base_files()
        logger.info(f"Deleted collection: {selected_knowledge_base}")
        return update_knowledge_base_dropdown(), "<div style='color: green; padding: 10px; border: 2px solid green; border-radius: 5px;'>Collection deleted successfully!</div>"
    return update_knowledge_base_dropdown(), "<div style='color: red; padding: 10px; border: 2px solid red; border-radius: 5px;'>Delete collection failed!</div>"
70
+
71
async def async_vectorize_files(selected_files, selected_knowledge_base, new_kb_name, chunk_size, chunk_overlap):
    """Vectorize the chosen files into a new or existing collection.

    Returns (summary text, status HTML). An empty selection is a failure.
    """
    if not selected_files:
        return "", "<div style='color: red; padding: 10px; border: 2px solid red; border-radius: 5px;'>Vectorization failed!</div>"

    creating_new = selected_knowledge_base == "创建知识库"
    knowledge_base = new_kb_name if creating_new else selected_knowledge_base
    if creating_new:
        vectordb.create_collection(selected_files, knowledge_base, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    else:
        vectordb.add_chroma(selected_files, knowledge_base, chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    # Mirror the change in the in-memory snapshot used by the UI.
    knowledge_base_files.setdefault(knowledge_base, []).extend(selected_files)

    logger.info(f"Vectorized files: {selected_files} for knowledge base: {knowledge_base}")
    await asyncio.sleep(0)  # yield so other tasks may run
    summary = f"Vectorized files: {', '.join(selected_files)}\nKnowledge Base: {knowledge_base}\nUploaded Files: {', '.join(uploaded_files)}"
    return summary, "<div style='color: green; padding: 10px; border: 2px solid green; border-radius: 5px;'>Vectorization successful!</div>"
88
+
89
def update_file_list():
    """Refresh the checkbox group to show every uploaded file, none selected."""
    return gr.update(value=[], choices=uploaded_files)
91
+
92
def search_knowledge_base(selected_knowledge_base):
    """List the files stored in the selected knowledge base (empty if unknown)."""
    kb_files = knowledge_base_files.get(selected_knowledge_base)
    if kb_files is None:
        return gr.update(choices=[], value=[])
    return gr.update(choices=kb_files, value=[])
97
+
98
def update_knowledge_base_files():
    """Re-read the collection -> files mapping from the vector store."""
    global knowledge_base_files
    knowledge_base_files = get_knowledge_base_files()
101
+
102
+ # 处理聊天消息的函数
103
+ chat_history = []
104
+
105
def safe_chat_response(model_dropdown, vector_dropdown, chat_knowledge_base_dropdown, chain_dropdown, message):
    """Wrap chat_response so UI callbacks surface errors instead of crashing."""
    try:
        return chat_response(model_dropdown, vector_dropdown, chat_knowledge_base_dropdown, chain_dropdown, message)
    except Exception as exc:
        logger.error(f"Error in chat response: {str(exc)}")
        return f"<div style='color: red;'>Error: {str(exc)}</div>", ""
111
+
112
def chat_response(model_dropdown, vector_dropdown, chat_knowledge_base_dropdown, chain_dropdown, message):
    """Answer `message`, optionally grounded in the selected knowledge base.

    Returns (rendered chat HTML, cleared input). With "仅使用模型" (or no
    knowledge base selected — previously this left `answer` unbound and raised
    NameError) the model chats over the running history; otherwise the chosen
    retrieval chain answers from the knowledge base.
    """
    global chat_history
    if message:
        chat_history.append(("User", message))
        use_kb = bool(chat_knowledge_base_dropdown) and chat_knowledge_base_dropdown != "仅使用模型"
        if not use_kb:
            # Model-only chat; also the safe fallback when nothing is selected.
            rag = RAG_class(model=model_dropdown, persist_directory=DB_directory)
            answer = rag.mult_chat(chat_history)
        else:
            rag = RAG_class(model=model_dropdown, embed=vector_dropdown, c_name=chat_knowledge_base_dropdown, persist_directory=DB_directory)
            if chain_dropdown == "复杂召回方式":
                questions = rag.decomposition_chain(message)
                answer = rag.rag_chain(questions)
            elif chain_dropdown == "简单召回方式":
                answer = rag.simple_chain(message)
            else:
                answer = rag.rerank_chain(message)

        response = f" {answer}"
        chat_history.append(("Bot", response))
    return format_chat_history(chat_history), ""
132
+
133
def clear_chat():
    """Reset the conversation and return the (now empty) rendered history."""
    global chat_history
    chat_history = []
    return format_chat_history(chat_history)
137
+
138
def format_chat_history(history):
    """Render (speaker, message) pairs as HTML chat bubbles.

    User turns are right-aligned green bubbles; bot turns are left-aligned grey
    bubbles, and a bot message containing a ``` fence gets its fenced content
    rendered inside <pre><code>.
    """
    pieces = []
    for speaker, msg in history:
        if speaker == "User":
            pieces.append(f'''
            <div style="text-align: right; margin: 10px;">
                <div style="display: inline-block; background-color: #DCF8C6; padding: 10px; border-radius: 10px; max-width: 60%;">
                    {msg}
                </div>
                <b>:User</b>
            </div>
            ''')
        elif "```" in msg:  # bot turn containing a fenced code snippet
            code_content = msg.split("```")[1]
            pieces.append(f'''
            <div style="text-align: left; margin: 10px;">
                <b>Bot:</b>
                <div style="display: inline-block; background-color: #F1F0F0; padding: 10px; border-radius: 10px; max-width: 60%;">
                    <pre><code>{code_content}</code></pre>
                </div>
            </div>
            ''')
        else:
            pieces.append(f'''
            <div style="text-align: left; margin: 10px;">
                <b>Bot:</b>
                <div style="display: inline-block; background-color: #F1F0F0; padding: 10px; border-radius: 10px; max-width: 60%;">
                    {msg}
                </div>
            </div>
            ''')
    return "".join(pieces)
171
+
172
def clear_status():
    """Blank all four status banners (run from a background thread after actions).

    NOTE(review): calling .update("") on a component instance is a no-op in
    recent Gradio versions — verify against the Gradio version in use.
    """
    for banner in (upload_status, delete_status, vectorize_status, delete_collection_status):
        banner.update("")
177
+
178
def handle_knowledge_base_selection(selected_knowledge_base):
    """Adapt KB-tab controls (name box, file list, chain dropdown) to the choice."""
    if selected_knowledge_base == "创建知识库":
        # Creating a new KB: show the name textbox, clear the file list.
        return gr.update(visible=True, interactive=True), gr.update(choices=[], value=[]), gr.update(visible=False)
    if selected_knowledge_base == "仅使用模型":
        return gr.update(visible=False, interactive=False), gr.update(choices=[], value=[]), gr.update(visible=False)
    # Existing KB: hide the name box, list its files, expose the chain selector.
    return gr.update(visible=False, interactive=False), search_knowledge_base(selected_knowledge_base), gr.update(visible=True)
185
+
186
def update_knowledge_base_dropdown():
    """Dropdown choices: the create-new sentinel plus every known collection."""
    return gr.update(choices=["创建知识库", *knowledge_base_files])
190
+
191
def update_chat_knowledge_base_dropdown():
    """Chat dropdown choices: the model-only sentinel plus every known collection."""
    return gr.update(choices=["仅使用模型", *knowledge_base_files])
195
+
196
+
197
+ # SearxNG搜索函数
198
# SearxNG search helper.
def search_searxng(query):
    """Query the local SearxNG instance for `query` and return its JSON payload."""
    searxng_url = 'http://localhost:8080/search'  # replace with your SearxNG instance URL
    response = requests.get(searxng_url, params={'q': query, 'format': 'json'})
    response.raise_for_status()
    return response.json()
207
+
208
+
209
+ # Ollama总结函数
210
# Ollama summarization helper.
def summarize_with_ollama(model_dropdown, text, question):
    """Ask the selected Ollama model to answer `question` from `text`; returns JSON."""
    prompt = """
    根据下边的内容,回答用户问题,
    内容为:‘{0}‘\n
    问题为:{1}
    """.format(text, question)
    ollama_url = 'http://localhost:11434/api/generate'  # replace with your Ollama instance URL
    payload = {
        'model': model_dropdown,
        "prompt": prompt,
        "stream": False
    }
    response = requests.post(ollama_url, json=payload)
    response.raise_for_status()
    return response.json()
225
+
226
+
227
+ # 处理函数
228
# Search the web, then summarize the hits with the chosen model.
def ai_web_search(model_dropdown, user_query):
    """Web-search `user_query` via SearxNG and summarize the results with Ollama."""
    search_results = search_searxng(user_query)
    snippets = [hit['title'] + "\n" + hit['content'] for hit in search_results['results']]
    combined_text = "\n\n".join(snippets)

    summary = summarize_with_ollama(model_dropdown, combined_text, user_query)
    return summary['response']
239
+ # 添加新的函数来处理AI网络搜索
240
+ # def ai_web_search(model_dropdown, query):
241
+ # try:
242
+ # # 这里添加实际的网络搜索和AI处理逻辑
243
+ # # 这只是一个示例,您需要根据实际情况实现
244
+ # search_result = f"搜索结果: {query}"
245
+ # ai_response = f"AI回答: 基于搜索结果,对于'{query}'的回答是..."
246
+ # return f"{search_result}\n\n{ai_response}"
247
+ # except Exception as e:
248
+ # logger.error(f"Error in AI web search: {str(e)}")
249
+ # return f"<div style='color: red;'>Error: {str(e)}</div>"
250
+
251
+ # 创建 Gradio 界面
252
# Build the Gradio UI: three tabs (knowledge base management, chat, web search).
with gr.Blocks() as demo:
    with gr.Column():
        # Page title
        title = gr.HTML("<h1 style='text-align: center; font-size: 32px; font-weight: bold;'>RAG精致系统</h1>")
        # Announcement banner
        announcement = gr.HTML("<div style='text-align: center; font-size: 18px; color: red;'>公告栏: RAG精致系统,【检索增强生成】系统!<br/>莫大大</div>")

    with gr.Tabs():
        with gr.TabItem("知识库"):
            # Knowledge-base management: pick/create a collection, upload and
            # vectorize files, delete files or whole collections.
            knowledge_base_dropdown = gr.Dropdown(choices=["创建知识库"] + list(knowledge_base_files.keys()),
                                                  label="选择知识库")
            new_kb_input = gr.Textbox(label="输入新的知识库名称", visible=False, interactive=True)
            file_input = gr.Files(label="Upload files")
            upload_btn = gr.Button("Upload")
            file_list = gr.CheckboxGroup(label="Uploaded Files")
            delete_btn = gr.Button("Delete Selected Files")
            with gr.Row():
                chunk_size_dropdown = gr.Dropdown(choices=[50, 100, 200, 300, 500, 700], label="chunk_size", value=200)
                chunk_overlap_dropdown = gr.Dropdown(choices=[20, 50, 100, 200], label="chunk_overlap", value=50)
            vectorize_btn = gr.Button("Vectorize Selected Files")
            delete_collection_btn = gr.Button("Delete Collection")
            # Status banners cleared asynchronously by clear_status().
            upload_status = gr.HTML()
            delete_status = gr.HTML()
            vectorize_status = gr.HTML()
            delete_collection_status = gr.HTML()

        with gr.TabItem("Chat"):
            with gr.Row():
                model_dropdown = gr.Dropdown(choices=get_llm(), label="模型")
                vector_dropdown = gr.Dropdown(choices=get_embeding_model(), label="向量")
                chat_knowledge_base_dropdown = gr.Dropdown(choices=["仅使用模型"] + vectordb.get_all_collections_name(), label="知识库")
                chain_dropdown = gr.Dropdown(choices=["复杂召回方式", "简单召回方式","rerank"], label="chain方式", visible=False)
            chat_display = gr.HTML(label="Chat History")
            chat_input = gr.Textbox(label="Type a message")
            chat_btn = gr.Button("Send")
            clear_btn = gr.Button("Clear Chat History")

        with gr.TabItem("AI网络搜索"):
            with gr.Row():
                web_search_model_dropdown = gr.Dropdown(choices=get_llm(), label="模型")
            web_search_output = gr.Textbox(label="搜索结果和AI回答", lines=10)
            web_search_input = gr.Textbox(label="输入搜索查询")

            web_search_btn = gr.Button("搜索")

    def handle_upload(files):
        # Upload, then fire a background thread to clear the status banner.
        upload_result, new_files, status = upload_files(files)
        threading.Thread(target=clear_status).start()
        return upload_result, new_files, status, update_chat_knowledge_base_dropdown()

    def handle_delete(selected_knowledge_base, selected_files):
        # Remove the vectors for files that actually exist in the collection,
        # then drop the names from the upload list.
        tmp = []
        cols_files_tmp = vectordb.get_collcetion_content_files(c_name=selected_knowledge_base)
        for i in selected_files:
            if i in cols_files_tmp:
                tmp.append(i)
        del cols_files_tmp
        if tmp:
            vectordb.del_files(tmp, c_name=selected_knowledge_base)
        del tmp
        delete_result, status = delete_files(selected_files)
        threading.Thread(target=clear_status).start()
        return delete_result, status, update_chat_knowledge_base_dropdown()

    def handle_vectorize(selected_files, selected_knowledge_base, new_kb_name, chunk_size, chunk_overlap):
        # Run the async vectorization to completion from this sync callback.
        vectorize_result, status = asyncio.run(async_vectorize_files(selected_files, selected_knowledge_base, new_kb_name, chunk_size, chunk_overlap))
        threading.Thread(target=clear_status).start()
        return vectorize_result, status, update_knowledge_base_dropdown(), update_chat_knowledge_base_dropdown()

    def handle_delete_collection(selected_knowledge_base):
        result, status = delete_collection(selected_knowledge_base)
        threading.Thread(target=clear_status).start()
        return result, status, update_chat_knowledge_base_dropdown()

    # Event wiring.
    knowledge_base_dropdown.change(
        handle_knowledge_base_selection,
        inputs=knowledge_base_dropdown,
        outputs=[new_kb_input, file_list, chain_dropdown]
    )
    upload_btn.click(handle_upload, inputs=file_input, outputs=[file_list, file_list, upload_status, chat_knowledge_base_dropdown])
    delete_btn.click(handle_delete, inputs=[knowledge_base_dropdown, file_list], outputs=[file_list, delete_status, chat_knowledge_base_dropdown])
    vectorize_btn.click(handle_vectorize, inputs=[file_list, knowledge_base_dropdown, new_kb_input, chunk_size_dropdown, chunk_overlap_dropdown],
                        outputs=[gr.Textbox(visible=False), vectorize_status, knowledge_base_dropdown, chat_knowledge_base_dropdown])
    delete_collection_btn.click(handle_delete_collection, inputs=knowledge_base_dropdown,
                                outputs=[knowledge_base_dropdown, delete_collection_status, chat_knowledge_base_dropdown])

    chat_btn.click(chat_response, inputs=[model_dropdown, vector_dropdown, chat_knowledge_base_dropdown, chain_dropdown, chat_input], outputs=[chat_display, chat_input])
    clear_btn.click(clear_chat, outputs=chat_display)

    # Show the chain selector only when a real knowledge base is chosen.
    chat_knowledge_base_dropdown.change(
        fn=lambda selected: gr.update(visible=selected != "仅使用模型"),
        inputs=chat_knowledge_base_dropdown,
        outputs=chain_dropdown
    )

    # AI web-search tab wiring.
    web_search_btn.click(
        ai_web_search,
        inputs=[web_search_model_dropdown, web_search_input],
        outputs=web_search_output
    )

demo.launch(debug=True,share=True)
embeding/__pycache__/asr_utils.cpython-310.pyc ADDED
Binary file (634 Bytes). View file
 
embeding/__pycache__/chromadb.cpython-310.pyc ADDED
Binary file (3.91 kB). View file
 
embeding/__pycache__/elasticsearchStore.cpython-310.pyc ADDED
Binary file (4.18 kB). View file
 
embeding/__pycache__/faissdb.cpython-310.pyc ADDED
Binary file (4.21 kB). View file
 
embeding/asr_utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# coding:utf-8
from funasr import AutoModel

# paraformer-zh is a multi-functional ASR model;
# enable vad, punc, spk models as needed.
model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc",
                  # spk_model="cam++"
                  )


def get_spk_txt(file):
    """Transcribe an audio file and write the text to embeding/tmp.txt.

    Returns the path of the written file so callers can load it like a CSV.
    """
    res = model.generate(input=file,
                         batch_size_s=300,
                         hotword='魔搭')
    print(res[0]["text"])
    fw = "embeding/tmp.txt"
    # `with` guarantees the handle is closed even if the write fails
    # (the original open()/close() pair leaked on exceptions).
    with open(fw, "w", encoding="utf-8") as f:
        f.write('"context"\n' + res[0]["text"])
    return fw
embeding/chromadb.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader,UnstructuredCSVLoader, UnstructuredPDFLoader,UnstructuredWordDocumentLoader,UnstructuredExcelLoader,UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from .asr_utils import get_spk_txt


class ChromaDB():
    """Persistent Chroma vector store with helpers to load, split and manage files."""

    def __init__(self, embedding="mofanke/acge_text_embedding:latest", persist_directory="./Chroma_db/"):
        self.embedding = OllamaEmbeddings(model=embedding)
        self.persist_directory = persist_directory
        self.chromadb = Chroma(persist_directory=persist_directory)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=50)

    def parse_data(self, file):
        """Load `file` into LangChain documents, choosing a loader by extension.

        Audio files are transcribed first via ASR. Returns [] for unsupported
        extensions — the original left `data` unbound and raised
        UnboundLocalError in that case.
        """
        data = []
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except Exception:
                # Not valid CSV — fall back to plain UTF-8 text.
                # (was a bare `except:` that also swallowed KeyboardInterrupt)
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
            # Transcribe the audio to text, load the transcript, then point each
            # document's metadata back at the original media file.
            fw = get_spk_txt(file)
            loaders = UnstructuredCSVLoader(fw)
            data = loaders.load()
            tmp = []
            for i in data:
                i.metadata["source"] = file
                tmp.append(i)
            data = tmp
        return data

    # Create a new collection and populate it.
    def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
        """Create collection `c_name` from `files`; returns the vector store."""
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("开始创建数据库 ....")
        tmps = []
        for file in files:
            tmps.extend(self.parse_data(file))

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = self.chromadb.from_documents(documents=splits, collection_name=c_name,
                                                   embedding=self.embedding, persist_directory=self.persist_directory)
        print("数据块总量:", vectorstore._collection.count())

        return vectorstore

    # Append data to an existing collection.
    def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
        """Add `files` to existing collection `c_name`; returns the vector store."""
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("开始添加文件...")
        tmps = []
        for file in files:
            tmps.extend(self.parse_data(file))

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = Chroma(persist_directory=self.persist_directory, collection_name=c_name,
                             embedding_function=self.embedding)
        vectorstore.add_documents(splits)
        print("数据块总量:", vectorstore._collection.count())

        return vectorstore

    # Delete specific files from a collection.
    def del_files(self, del_files_name, c_name):
        """Delete every chunk whose metadata source matches one of `del_files_name`."""
        vectorstore = self.chromadb._client.get_collection(c_name)
        del_ids = []
        vec_dict = vectorstore.get()
        for id, md in zip(vec_dict["ids"], vec_dict["metadatas"]):
            for dl in del_files_name:
                if dl in md["source"]:
                    del_ids.append(id)
        vectorstore.delete(ids=del_ids)
        print("数据块总量:", vectorstore.count())

        return vectorstore

    # Delete a whole knowledge-base collection.
    def delete_collection(self, c_name):
        """Drop collection `c_name` entirely."""
        self.chromadb._client.delete_collection(c_name)

    # List every collection currently in the store.
    def get_all_collections_name(self):
        """Names of every collection in the store."""
        return [col.name for col in self.chromadb._client.list_collections()]

    # List the files contained in one collection.
    def get_collcetion_content_files(self, c_name):
        """Distinct source file names stored in collection `c_name`.

        NOTE: the misspelled method name is kept — callers depend on it.
        """
        vectorstore = self.chromadb._client.get_collection(c_name)
        vec_dict = vectorstore.get()
        return list(set(md["source"] for md in vec_dict["metadatas"]))
embeding/elasticsearchStore.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from elasticsearch import Elasticsearch
2
+ from langchain_elasticsearch.vectorstores import ElasticsearchStore
3
+ from langchain_community.embeddings import OllamaEmbeddings
4
+ from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, \
5
+ UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from .asr_utils import get_spk_txt
8
+ import requests
9
+
10
+
11
+ class ElsStore():
12
+ def __init__(self, embedding="mofanke/acge_text_embedding:latest", es_url="http://localhost:9200",
13
+ index_name='test_index'):
14
+ self.embedding = OllamaEmbeddings(model=embedding)
15
+ self.es_url = es_url
16
+ self.elastic_vector_search = ElasticsearchStore(
17
+ es_url=self.es_url,
18
+ index_name=index_name,
19
+ embedding=self.embedding
20
+ )
21
+
22
+ def parse_data(self, file):
23
+ if "txt" in file.lower() or "csv" in file.lower():
24
+ try:
25
+ loaders = UnstructuredCSVLoader(file)
26
+ data = loaders.load()
27
+ except:
28
+ loaders = TextLoader(file, encoding="utf-8")
29
+ data = loaders.load()
30
+ if ".doc" in file.lower() or ".docx" in file.lower():
31
+ loaders = UnstructuredWordDocumentLoader(file)
32
+ data = loaders.load()
33
+ if "pdf" in file.lower():
34
+ loaders = UnstructuredPDFLoader(file)
35
+ data = loaders.load()
36
+ if ".xlsx" in file.lower():
37
+ loaders = UnstructuredExcelLoader(file)
38
+ data = loaders.load()
39
+ if ".md" in file.lower():
40
+ loaders = UnstructuredMarkdownLoader(file)
41
+ data = loaders.load()
42
+ if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
43
+ # 语音解析成文字
44
+ fw = get_spk_txt(file)
45
+ loaders = UnstructuredCSVLoader(fw)
46
+ data = loaders.load()
47
+ tmp = []
48
+ for i in data:
49
+ i.metadata["source"] = file
50
+ tmp.append(i)
51
+ data = tmp
52
+ return data
53
+
54
+ def get_count(self, c_name):
55
+ # 获取index-anme中的数据块数
56
+
57
+ # 初始化 Elasticsearch 客户端
58
+ es = Elasticsearch([{
59
+ 'host': self.es_url.split(":")[1][2:],
60
+ 'port': int(self.es_url.split(":")[2]),
61
+ 'scheme': 'http' # 指定使用的协议
62
+ }])
63
+
64
+ # 指定索引名称
65
+ index_name = c_name
66
+
67
+ # 获取文档总数
68
+ response = es.count(index=index_name)
69
+
70
+ # 输出文档总数
71
+ return response['count']
72
+
73
+ # 创建 新的index_name 并且初始化
74
+ def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
75
+ self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
76
+ print("开始创建数据库 ....")
77
+ tmps = []
78
+ for file in files:
79
+ data = self.parse_data(file)
80
+ tmps.extend(data)
81
+
82
+ splits = self.text_splitter.split_documents(tmps)
83
+
84
+ self.elastic_vector_search = ElasticsearchStore.from_documents(
85
+ documents=splits,
86
+ embedding=self.embedding,
87
+ es_url=self.es_url,
88
+ index_name=c_name,
89
+ )
90
+
91
+ self.elastic_vector_search.client.indices.refresh(index=c_name)
92
+
93
+ print("数据块总量:", self.get_count(c_name))
94
+
95
+ return self.elastic_vector_search
96
+
97
+ # 添加 数据到已有数据库
98
+ def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
99
+ self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
100
+ print("开始添加文件...")
101
+ tmps = []
102
+ for file in files:
103
+ data = self.parse_data(file)
104
+ tmps.extend(data)
105
+
106
+ splits = self.text_splitter.split_documents(tmps)
107
+
108
+ self.elastic_vector_search = ElasticsearchStore(
109
+ es_url=self.es_url,
110
+ index_name=c_name,
111
+ embedding=self.embedding
112
+ )
113
+ self.elastic_vector_search.add_documents(splits)
114
+ self.elastic_vector_search.client.indices.refresh(index=c_name)
115
+ print("数据块总量:", self.get_count(c_name))
116
+
117
+ return self.elastic_vector_search
118
+
119
+ # 删除某个 知识库 collection
120
+ def delete_collection(self, c_name):
121
+ url = self.es_url + "/" + c_name
122
+ # 发送 DELETE 请求
123
+ response = requests.delete(url)
124
+
125
+ # 检查响应状态码
126
+ if response.status_code == 200:
127
+ return f"索引 'test-basic1' 已成功删除。"
128
+ elif response.status_code == 404:
129
+ return f"索引 'test-basic1' 不存在。"
130
+ else:
131
+ return f"删除索引时出错: {response.status_code}, {response.text}"
132
+
133
+ # 获取目前所有 index_names
134
+ def get_all_collections_name(self):
135
+ indices = self.elastic_vector_search.client.indices.get_alias()
136
+ index_names = list(indices.keys())
137
+
138
+ return index_names
139
+
140
+ def get_collcetion_content_files(self,c_name):
141
+ return []
142
+
143
+ # 删除 某个collection中的 某个文件
144
+ def del_files(self, del_files_name, c_name):
145
+ return None
146
+
147
+
embeding/faissdb.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import FAISS
2
+ from langchain_community.embeddings import OllamaEmbeddings
3
+ from langchain_community.document_loaders import TextLoader,UnstructuredCSVLoader, UnstructuredPDFLoader,UnstructuredWordDocumentLoader,UnstructuredExcelLoader,UnstructuredMarkdownLoader
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ import shutil
6
+ import os
7
+ from .asr_utils import get_spk_txt
8
+
9
+
10
+ class FaissDB():
11
+ def __init__(self, embedding="mofanke/acge_text_embedding:latest", persist_directory="./Faiss_db/"):
12
+
13
+ self.embedding = OllamaEmbeddings(model=embedding)
14
+ self.persist_directory = persist_directory
15
+ self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50, add_start_index=True)
16
+
17
+ def parse_data(self,file):
18
+ if "txt" in file.lower() or "csv" in file.lower():
19
+ try:
20
+ loaders = UnstructuredCSVLoader(file)
21
+ data = loaders.load()
22
+ except:
23
+ loaders = TextLoader(file,encoding="utf-8")
24
+ data = loaders.load()
25
+ if ".doc" in file.lower() or ".docx" in file.lower():
26
+ loaders = UnstructuredWordDocumentLoader(file)
27
+ data = loaders.load()
28
+ if "pdf" in file.lower():
29
+ loaders = UnstructuredPDFLoader(file)
30
+ data = loaders.load()
31
+ if ".xlsx" in file.lower():
32
+ loaders = UnstructuredExcelLoader(file)
33
+ data = loaders.load()
34
+ if ".md" in file.lower():
35
+ loaders = UnstructuredMarkdownLoader(file)
36
+ data = loaders.load()
37
+ if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
38
+ # 语音解析成文字
39
+ fw = get_spk_txt(file)
40
+ loaders = UnstructuredCSVLoader(fw)
41
+ data = loaders.load()
42
+ tmp = []
43
+ for i in data:
44
+ i.metadata["source"] = file
45
+ tmp.append(i)
46
+ data = tmp
47
+ return data
48
+
49
+ # 创建 新的collection 并且初始化
50
+ def create_collection(self, files, c_name,chunk_size=200, chunk_overlap=50):
51
+ self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
52
+ print("开始创建数据库 ....")
53
+ tmps = []
54
+ for file in files:
55
+ data = self.parse_data(file)
56
+ tmps.extend(data)
57
+
58
+ splits = self.text_splitter.split_documents(tmps)
59
+
60
+ vectorstore = FAISS.from_documents(documents=splits,
61
+ embedding=self.embedding)
62
+ vectorstore.save_local(self.persist_directory + c_name)
63
+ print("数据块总量:", vectorstore.index.ntotal)
64
+
65
+ return vectorstore
66
+
67
+ # 添加 数据到已有数据库
68
+ def add_chroma(self, files, c_name,chunk_size=200, chunk_overlap=50):
69
+ self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
70
+ print("开始添加文件...")
71
+ tmps = []
72
+ for file in files:
73
+ data = self.parse_data(file)
74
+ tmps.extend(data)
75
+
76
+ splits = self.text_splitter.split_documents(tmps)
77
+
78
+ vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
79
+ allow_dangerous_deserialization=True)
80
+ vectorstore.add_documents(documents=splits)
81
+ vectorstore.save_local("Faiss_db/" + c_name)
82
+ print("数据块总量:", vectorstore.index.ntotal)
83
+
84
+ return vectorstore
85
+
86
+ # 删除 某个collection中的 某个文件
87
+ def del_files(self, del_files_name, c_name):
88
+
89
+ vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
90
+ allow_dangerous_deserialization=True)
91
+ del_ids = []
92
+ vec_dict = vectorstore.docstore._dict
93
+ for id, md in vec_dict.items():
94
+ for dl in del_files_name:
95
+ if dl in md.metadata["source"]:
96
+ del_ids.append(id)
97
+ vectorstore.delete(ids=del_ids)
98
+ vectorstore.save_local(self.persist_directory + c_name)
99
+ print("数据块总量:", vectorstore.index.ntotal)
100
+
101
+ return vectorstore
102
+
103
+ # 删除某个 知识库 collection
104
+ def delete_collection(self, c_name):
105
+ shutil.rmtree(self.persist_directory + c_name)
106
+
107
+ # 获取目前所有 collection
108
+ def get_all_collections_name(self):
109
+ cl_names = [i for i in os.listdir(self.persist_directory) if os.path.isdir(self.persist_directory+i)]
110
+
111
+ return cl_names
112
+
113
+ # 获取 collection中的所有文件
114
+ def get_collcetion_content_files(self, c_name):
115
+ vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
116
+ allow_dangerous_deserialization=True)
117
+ c_files = []
118
+ vec_dict = vectorstore.docstore._dict
119
+ for _, md in vec_dict.items():
120
+ c_files.append(md.metadata["source"])
121
+
122
+ return list(set(c_files))
123
+
124
+
125
+ # if __name__ == "__main__":
126
+ # chromadb = FaissDB()
127
+ # c_name = "sss3"
128
+ #
129
+ # print(chromadb.get_all_collections_name())
130
+ # chromadb.create_collection(["data/jl.txt", "data/jl.pdf"], c_name=c_name)
131
+ # print(chromadb.get_all_collections_name())
132
+ # chromadb.add_chroma(["data/tmp.txt"], c_name=c_name)
133
+ # print(c_name, "包含的文件:", chromadb.get_collcetion_content_files(c_name))
134
+ # chromadb.del_files(["data/tmp.txt"], c_name=c_name)
135
+ # print(c_name, "包含的文件:", chromadb.get_collcetion_content_files(c_name))
136
+ # print(chromadb.get_all_collections_name())
137
+ # chromadb.delete_collection(c_name=c_name)
138
+ # print(chromadb.get_all_collections_name())
embeding/tmp.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ "context"
2
+ 你是不是觉得自己说话的声音直来直去呢?现在告诉你一个主持人吐字的小秘密,那就是每个字在口腔当中像是翻跟头一样打一圈再出来。比如说故人西辞黄鹤楼,而不是故人西辞黄鹤楼。再比如说乌衣巷口夕阳斜,而不是乌衣巷口夕阳斜,你也试试看抖音。
graph_demo_ui.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ from flask import Flask, render_template, request, jsonify
3
+ import json
4
+ from dotenv import load_dotenv
5
+ from langchain_community.llms import Ollama
6
+
7
+
8
+ load_dotenv()
9
+
10
+ app = Flask(__name__)
11
+
12
+ # 测试了 llama3:8b,gemma2:9b,qwen2:7b,glm4:9b,arcee-ai/arcee-agent:latest 目前来看 qwen2:7 效果最好
13
+ llm = Ollama(model="qwen2:7b")
14
+
15
+
16
+ json_example = {'edges': [{'data': {'color': '#FFA07A',
17
+ 'label': 'label 1',
18
+ 'source': 'source 1',
19
+ 'target': 'target 1'}},
20
+ {'data': {'color': '#FFA07A',
21
+ 'label': 'label 2',
22
+ 'source': 'source 2',
23
+ 'target': 'target 2'}}
24
+ ],
25
+ 'nodes': [{'data': {'color': '#FFC0CB', 'id': 'id 1', 'label': 'label 1'}},
26
+ {'data': {'color': '#90EE90', 'id': 'id 2', 'label': 'label 2'}},
27
+ {'data': {'color': '#87CEEB', 'id': 'id 3', 'label': 'label 3'}}]}
28
+
29
+
30
+
31
+ __retriever_prompt = f"""
32
+ 您是一名专门从事知识图谱创建的人工智能专家,目标是根据给定的输入或请求捕获关系。
33
+ 基于各种形式的用户输入,如段落、电子邮件、文本文件等。
34
+ 你的任务是根据输入创建一个知识图谱。
35
+ nodes必须具有label参数,并且label是来自输入的词语或短语,nodes必须具有id参数,id的格式是"id_数字",不能重复。
36
+ edges还必须有一个label参数,其中label是输入中的直接词语或短语,edges中的source和target取自nodes中的id。
37
+ 仅使用JSON进行响应,其格式可以在python中进行jsonify,并直接输入cy.add(data),包括“color”属性,以在前端显示图形。
38
+ 您可以参考给定的示例:{json_example}。存储node和edge的数组中,最后一个元素后边不要有逗号,
39
+ 确保边的目标和源与现有节点匹配。
40
+ 不要在JSON的上方和下方包含markdown三引号,直接用花括号括起来。
41
+ """
42
+
43
+
44
+ def generate_graph_info(raw_text: str) -> str | None:
45
+ """
46
+ generate graph info from raw text
47
+ :param raw_text:
48
+ :return:
49
+ """
50
+ messages = [
51
+ {"role": "system", "content": "你现在扮演信息抽取的角色,要求根据用户输入和AI的回答,正确提取出信息,记得不多对实体进行翻译。"},
52
+ {"role": "user", "content": raw_text},
53
+ {"role": "user", "content": __retriever_prompt}
54
+ ]
55
+ print("解析中....")
56
+ for i in range(3):
57
+ graph_info_result = llm.invoke(messages)
58
+ if len(graph_info_result)<10:
59
+ print("-------",i,"-------------------")
60
+ continue
61
+ else:
62
+ break
63
+ print(graph_info_result)
64
+ return graph_info_result
65
+
66
+
67
+ @app.route('/')
68
+ def index():
69
+ return render_template('index.html')
70
+
71
+
72
+ @app.route('/update_graph', methods=['POST'])
73
+ def update_graph():
74
+ raw_text = request.json.get('text', '')
75
+ try:
76
+ result = generate_graph_info(raw_text)
77
+ if '```' in result:
78
+ graph_data=json.loads(result.split('```',2)[1].replace("json", ''))
79
+ else:
80
+ graph_data=json.loads(result)
81
+ return graph_data
82
+ except Exception as e:
83
+ return {'error': f"Error parsing graph data: {str(e)}"}
84
+
85
+
86
+ if __name__ == '__main__':
87
+ app.run(host='0.0.0.0', port=7860)
img/graph-tool.png ADDED

Git LFS Details

  • SHA256: 8aa90d4cba907a57c8d5cc5e2c193240955c1c9cee23dcffc8bbd597616f6bed
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
img/readme.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 1
img/zhu.png ADDED

Git LFS Details

  • SHA256: 71f5efce94f123211ea1b7392e9644f953ae158ee10a3c4d28f522a23b9387b8
  • Pointer size: 131 Bytes
  • Size of remote file: 630 kB
img/zhuye.png ADDED
img//345/244/215/346/235/202/346/226/271/345/274/217.png ADDED
img//345/276/256/344/277/241/345/233/276/347/211/207_20240524180648.jpg ADDED
rag/__init__.py ADDED
File without changes
rag/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (132 Bytes). View file
 
rag/__pycache__/config.cpython-310.pyc ADDED
Binary file (364 Bytes). View file
 
rag/__pycache__/rag_class.cpython-310.pyc ADDED
Binary file (5.39 kB). View file
 
rag/__pycache__/rerank.cpython-310.pyc ADDED
Binary file (878 Bytes). View file
 
rag/__pycache__/rerank.cpython-39.pyc ADDED
Binary file (869 Bytes). View file
 
rag/__pycache__/rerank_code.cpython-310.pyc ADDED
Binary file (883 Bytes). View file
 
rag/rag_class.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import Chroma,FAISS
2
+ from langchain_community.llms import Ollama
3
+ from langchain_core.output_parsers import StrOutputParser
4
+ from langchain_community.embeddings import OllamaEmbeddings
5
+ from langchain_core.runnables import RunnablePassthrough
6
+ from operator import itemgetter
7
+ from langchain.prompts import ChatPromptTemplate
8
+ from rerank_code import rerank_topn
9
+ from Config.config import VECTOR_DB,DB_directory
10
+ from langchain_elasticsearch.vectorstores import ElasticsearchStore
11
+
12
+
13
+ class RAG_class:
14
+ def __init__(self, model="qwen2:7b", embed="milkey/dmeta-embedding-zh:f16", c_name="sss1",
15
+ persist_directory="E:/pycode/jupyter_code/langGraph/sss2/chroma.sqlite3/",es_url="http://localhost:9200"):
16
+ template = """
17
+ 根据上下文回答以下问题,不要自己发挥,要根据以下参考内容总结答案,如果以下内容无法得到答案,就返回无法根据参考内容获取答案,
18
+
19
+ 参考内容为:{context}
20
+
21
+ 问题: {question}
22
+ """
23
+
24
+ self.prompts = ChatPromptTemplate.from_template(template)
25
+
26
+ # 使用 问题扩展+结果递归方式得到最终答案
27
+ template1 = """你是一个乐于助人的助手,可以生成与输入问题相关的多个子问题。
28
+ 目标是将输入分解为一组可以单独回答的子问题/子问题。
29
+ 生成多个与以下内容相关的搜索查询:{question}
30
+ 输出4个相关问题,以换行符隔开:"""
31
+ self.prompt_questions = ChatPromptTemplate.from_template(template1)
32
+
33
+ # 构建 问答对
34
+ template2 = """
35
+ 以下是您需要回答的问题:
36
+
37
+ \n--\n {question} \n---\n
38
+
39
+ 以下是任何可用的背景问答对:
40
+
41
+ \n--\n {q_a_pairs} \n---\n
42
+
43
+ 以下是与该问题相关的其他上下文:
44
+
45
+ \n--\n {context} \n---\n
46
+
47
+ 使用以上上下文和背景问答对来回答问题,问题是:{question} ,答案是:
48
+ """
49
+ self.decomposition_prompt = ChatPromptTemplate.from_template(template2)
50
+
51
+ self.llm = Ollama(model=model)
52
+ self.embeding = OllamaEmbeddings(model=embed)
53
+ if VECTOR_DB==1:
54
+ self.vectstore = Chroma(embedding_function=self.embeding, collection_name=c_name,
55
+ persist_directory=persist_directory)
56
+ elif VECTOR_DB ==2:
57
+ self.vectstore = FAISS.load_local(folder_path=persist_directory + c_name, embeddings=self.embeding,
58
+ allow_dangerous_deserialization=True)
59
+ elif VECTOR_DB ==3:
60
+ self.vectstore = ElasticsearchStore(
61
+ es_url=es_url,
62
+ index_name=c_name,
63
+ embedding=self.embeding
64
+ )
65
+ self.retriever = self.vectstore.as_retriever()
66
+ try:
67
+ if VECTOR_DB==1:
68
+ self.vectstore = Chroma(embedding_function=self.embeding, collection_name=c_name,
69
+ persist_directory=persist_directory)
70
+ elif VECTOR_DB ==2:
71
+ self.vectstore = FAISS.load_local(folder_path=persist_directory + c_name, embeddings=self.embeding,
72
+ allow_dangerous_deserialization=True)
73
+ elif VECTOR_DB ==3:
74
+ self.vectstore = ElasticsearchStore(
75
+ es_url=es_url,
76
+ index_name=c_name,
77
+ embedding=self.embeding
78
+ )
79
+ self.retriever = self.vectstore.as_retriever()
80
+ except Exception as e:
81
+ print("仅模型时无需加载数据库",e)
82
+ #
83
+ # Post-processing
84
+ def format_docs(self,docs):
85
+ return "\n\n".join(doc.page_content for doc in docs)
86
+ # 传统方式召回,单问题召回,然后llm总结答案回答
87
+ def simple_chain(self,question):
88
+ _chain = (
89
+ {"context": self.retriever|self.format_docs,"question":RunnablePassthrough()}
90
+ |self.prompts
91
+ |self.llm
92
+ |StrOutputParser()
93
+ )
94
+ answer = _chain.invoke({"question":question})
95
+ return answer
96
+
97
+ def rerank_chain(self,question):
98
+ retriever = self.vectstore.as_retriever(search_kwargs={"k": 10})
99
+ docs = retriever.invoke(question)
100
+ docs = rerank_topn(question,docs,N=5)
101
+ _chain = (
102
+ self.prompts
103
+ | self.llm
104
+ | StrOutputParser()
105
+ )
106
+ answer = _chain.invoke({"context":self.format_docs(docs),"question": question})
107
+ return answer
108
+
109
+ def format_qa_pairs(self, question, answer):
110
+ formatted_string = ""
111
+ formatted_string += f"Question: {question}\nAnswer:{answer}\n\n"
112
+ return formatted_string
113
+
114
+ # 获取问题的 扩展问题
115
+ def decomposition_chain(self, question):
116
+ _chain = (
117
+ {"question": RunnablePassthrough()}
118
+ | self.prompt_questions
119
+ | self.llm
120
+ | StrOutputParser()
121
+ | (lambda x: x.split("\n"))
122
+ )
123
+
124
+ questions = _chain.invoke({"question": question}) + [question]
125
+
126
+ return questions
127
+ # 多问题递归召回,每次召回后,问题和答案同时作为下一次召回的参考,再次用新问题召回
128
+ def rag_chain(self, questions):
129
+ q_a_pairs = ""
130
+ for q in questions:
131
+ _chain = (
132
+ {"context": itemgetter("question") | self.retriever,
133
+ "question": itemgetter("question"),
134
+ "q_a_pairs": itemgetter("q_a_paris")
135
+ }
136
+ | self.decomposition_prompt
137
+ | self.llm
138
+ | StrOutputParser()
139
+ )
140
+
141
+ answer = _chain.invoke({"question": q, "q_a_paris": q_a_pairs})
142
+ q_a_pairs = self.format_qa_pairs(q, answer)
143
+ q_a_pairs = q_a_pairs + "\n----\n" + q_a_pairs
144
+ return answer
145
+
146
+ # 将聊天历史格式化为一个字符串
147
+ def format_chat_history(self,history):
148
+ formatted_history = ""
149
+ for role,content in history:
150
+ formatted_history += f"{role}: {content}\n"
151
+ return formatted_history
152
+ # 基于ollama大模型的大模型 多轮对话,不使用知识库的
153
+ def mult_chat(self,chat_history):
154
+ # 格式化聊天历史
155
+ formatted_history = self.format_chat_history(chat_history)
156
+
157
+ # 调用模型生成回复
158
+ response = self.llm.invoke(formatted_history)
159
+ return response
160
+
161
+
162
+
163
+ # if __name__ == "__main__":
164
+ # rag = RAG_class(model="deepseek-r1:14b")
165
+ # question = "人卫社官网网址是?"
166
+ # questions = rag.decomposition_chain(question)
167
+ # print(questions)
168
+ # answer = rag.rag_chain(questions)
169
+ # print(answer)
rag/rerank_code.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+
4
+ tokenizer = AutoTokenizer.from_pretrained('E:\\model\\bge-reranker-large')
5
+ model = AutoModelForSequenceClassification.from_pretrained('E:\\model\\bge-reranker-large')
6
+ model.eval()
7
+
8
+
9
+ def rerank_topn(question,docs,N=5):
10
+ pairs = []
11
+ for i in docs:
12
+ pairs.append([question,i.page_content])
13
+
14
+ with torch.no_grad():
15
+ inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
16
+ scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
17
+ scores = scores.argsort().numpy()[::-1][:N]
18
+ bk = []
19
+ for i in scores:
20
+ bk.append(docs[i])
21
+ return bk
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.29.0
2
+ langchain-community==0.2.6
3
+ langchain==0.2.6
4
+ langchain-core==0.2.11
5
+ requests
6
+ transformers==4.41.1
7
+ unstructured==0.7.12
8
+ funasr==1.0.24
9
+ modelscope
10
+ chromadb
test/__init__.py ADDED
File without changes
test/graph2neo4j.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import sys
3
+ sys.path.append(r"..//")#
4
+ from Neo4j.neo4j_op import KnowledgeGraph
5
+ from Neo4j.graph_extract import update_graph
6
+ from Config.config import neo4j_host,neo4j_name,neo4j_pwd
7
+
8
+
9
+
10
+ kg = KnowledgeGraph(neo4j_host,neo4j_name,neo4j_pwd)
11
+
12
+
13
+ if __name__ == "__main__":
14
+
15
+ text = """范冰冰,1981年9月16日生于山东青岛,毕业于上海师范大学谢晋影视艺术学院,中国女演员,歌手。
16
+ 1998年参演电视剧《还珠格格》成名。2004年主演电影《手机》获得第27届大众电影百花奖最佳女演员奖。"""
17
+ res = update_graph(text)
18
+ # 批量创建节点
19
+ nodes = kg.create_nodes("node", res["nodes"])
20
+ print(nodes)
21
+ # 批量创建关系
22
+ relationships = kg.create_relationships([
23
+ ("node", {"name": edge["source"]}, "node", {"name": edge["target"]}, edge["label"]) for edge in res["edges"]
24
+ ])
25
+ print(relationships)