Spaces:
Sleeping
Sleeping
Update src/vector_db.py
Browse files- src/vector_db.py +11 -13
src/vector_db.py
CHANGED
|
@@ -17,9 +17,7 @@ class VectorDB:
|
|
| 17 |
db_location = ''
|
| 18 |
|
| 19 |
def __init__(self, emb_model, db_location, actions_list_file_path, num_sub_vectors, batch_size):
|
| 20 |
-
self.
|
| 21 |
-
self.db_location = db_location
|
| 22 |
-
|
| 23 |
emb_config = AutoConfig.from_pretrained(emb_model)
|
| 24 |
emb_dimension = emb_config.hidden_size
|
| 25 |
|
|
@@ -50,7 +48,7 @@ class VectorDB:
|
|
| 50 |
pa.field(self.name_column, pa.string())
|
| 51 |
]
|
| 52 |
)
|
| 53 |
-
tbl = db.create_table(
|
| 54 |
|
| 55 |
|
| 56 |
df = pd.read_csv(actions_list_file_path)
|
|
@@ -76,23 +74,23 @@ class VectorDB:
|
|
| 76 |
tbl.add(df)
|
| 77 |
except:
|
| 78 |
print(f"batch {i} was skipped")
|
|
|
|
|
|
|
|
|
|
| 79 |
print("Vector generation done.")
|
| 80 |
|
| 81 |
|
| 82 |
-
def get_embedding_db_as_pandas(self):
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
|
| 87 |
|
| 88 |
|
| 89 |
def retrieve_prefiltered_hits(self, query, k):
|
| 90 |
-
db = lancedb.connect(".lancedb")
|
| 91 |
-
table = db.open_table(self.table_name)
|
| 92 |
-
retriever = SentenceTransformer(self.emb_model)
|
| 93 |
|
| 94 |
-
query_vec = retriever.encode(query)
|
| 95 |
-
documents = table.search(query_vec, vector_column_name=self.vector_column).limit(k).to_list()
|
| 96 |
names = [doc[self.name_column] for doc in documents]
|
| 97 |
descriptions = [doc[self.description_column] for doc in documents]
|
| 98 |
|
|
|
|
| 17 |
db_location = ''
|
| 18 |
|
| 19 |
def __init__(self, emb_model, db_location, actions_list_file_path, num_sub_vectors, batch_size):
|
| 20 |
+
self.retriever = SentenceTransformer(emb_model)
|
|
|
|
|
|
|
| 21 |
emb_config = AutoConfig.from_pretrained(emb_model)
|
| 22 |
emb_dimension = emb_config.hidden_size
|
| 23 |
|
|
|
|
| 48 |
pa.field(self.name_column, pa.string())
|
| 49 |
]
|
| 50 |
)
|
| 51 |
+
tbl = db.create_table(table_name, schema=schema, mode="overwrite")
|
| 52 |
|
| 53 |
|
| 54 |
df = pd.read_csv(actions_list_file_path)
|
|
|
|
| 74 |
tbl.add(df)
|
| 75 |
except:
|
| 76 |
print(f"batch {i} was skipped")
|
| 77 |
+
|
| 78 |
+
self.db = db
|
| 79 |
+
self.table = tbl
|
| 80 |
print("Vector generation done.")
|
| 81 |
|
| 82 |
|
| 83 |
+
# def get_embedding_db_as_pandas(self):
|
| 84 |
+
# db = lancedb.connect(self.db_location)
|
| 85 |
+
# tbl = db.open_table(self.table_name)
|
| 86 |
+
# return tbl.to_pandas()
|
| 87 |
|
| 88 |
|
| 89 |
|
| 90 |
def retrieve_prefiltered_hits(self, query, k):
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
+
query_vec = self.retriever.encode(query)
|
| 93 |
+
documents = self.table.search(query_vec, vector_column_name=self.vector_column).limit(k).to_list()
|
| 94 |
names = [doc[self.name_column] for doc in documents]
|
| 95 |
descriptions = [doc[self.description_column] for doc in documents]
|
| 96 |
|