from elasticsearch import Elasticsearch
from langchain_elasticsearch.vectorstores import ElasticsearchStore
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, \
    UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from .asr_utils import get_spk_txt
import requests

class ElsStore():
    """Elasticsearch-backed vector store wrapper using Ollama embeddings."""

    def __init__(self, embedding="mofanke/acge_text_embedding:latest", es_url="http://localhost:9200",
                 index_name='test_index'):
        self.embedding = OllamaEmbeddings(model=embedding)
        self.es_url = es_url
        self.elastic_vector_search = ElasticsearchStore(
            es_url=self.es_url,
            index_name=index_name,
            embedding=self.embedding
        )
    def parse_data(self, file):
        # Route the file to a document loader based on its name / extension.
        data = []
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except Exception:
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
            # Transcribe the audio to text, then load the transcript.
            fw = get_spk_txt(file)
            loaders = UnstructuredCSVLoader(fw)
            data = loaders.load()
        # Point every document's source metadata back at the original file.
        tmp = []
        for i in data:
            i.metadata["source"] = file
            tmp.append(i)
        data = tmp
        return data
    def get_count(self, c_name):
        # Return the number of chunks stored in the given index.
        # Initialize an Elasticsearch client from the configured URL.
        es = Elasticsearch([{
            'host': self.es_url.split(":")[1][2:],
            'port': int(self.es_url.split(":")[2]),
            'scheme': 'http'  # protocol to use
        }])
        # Target index name
        index_name = c_name
        # Get the total document count
        response = es.count(index=index_name)
        return response['count']
    # Create a new index_name and initialize it with the given files.
    def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Creating the knowledge base ...")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)
        splits = self.text_splitter.split_documents(tmps)
        self.elastic_vector_search = ElasticsearchStore.from_documents(
            documents=splits,
            embedding=self.embedding,
            es_url=self.es_url,
            index_name=c_name,
        )
        self.elastic_vector_search.client.indices.refresh(index=c_name)
        print("Total number of chunks:", self.get_count(c_name))
        return self.elastic_vector_search
    # Add data from files to an existing collection.
    def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Adding files ...")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)
        splits = self.text_splitter.split_documents(tmps)
        self.elastic_vector_search = ElasticsearchStore(
            es_url=self.es_url,
            index_name=c_name,
            embedding=self.embedding
        )
        self.elastic_vector_search.add_documents(splits)
        self.elastic_vector_search.client.indices.refresh(index=c_name)
        print("Total number of chunks:", self.get_count(c_name))
        return self.elastic_vector_search
    # Delete a knowledge-base collection (index).
    def delete_collection(self, c_name):
        url = self.es_url + "/" + c_name
        # Send a DELETE request to the index endpoint.
        response = requests.delete(url)
        # Check the response status code.
        if response.status_code == 200:
            return f"Index '{c_name}' was deleted successfully."
        elif response.status_code == 404:
            return f"Index '{c_name}' does not exist."
        else:
            return f"Error deleting index: {response.status_code}, {response.text}"
    # Get all current index names.
    def get_all_collections_name(self):
        indices = self.elastic_vector_search.client.indices.get_alias()
        index_names = list(indices.keys())
        return index_names

    def get_collcetion_content_files(self, c_name):
        # Not implemented yet: should return the list of source files in the collection.
        return []

    # Delete specific files from a collection.
    def del_files(self, del_files_name, c_name):
        # Not implemented yet.
        return None