import json
import faiss
import numpy as np
from typing import List, Tuple
from text2vec import SentenceModel


class JSONLIndexer(object):
    """
    Retriever over a JSONL file (using precomputed embeddings).
    """
    def __init__(self, vector_sz: int, n_subquantizers=0, n_bits=8, model: SentenceModel = None, **kwargs):
        """
        Initialize the indexer and choose the FAISS index type.
        :param vector_sz: dimensionality of the embedding vectors
        :param n_subquantizers: number of subquantizers (0 selects a flat inner-product index)
        :param n_bits: number of bits per subvector
        :param model: SentenceModel used to re-embed incoming queries
        """
        if n_subquantizers > 0:
            self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
        else:
            self.index = faiss.IndexFlatIP(vector_sz)
        self.index_id_to_data = []  # maps FAISS index IDs to positions in self.data
        self.data = []  # all JSON records
        self.id_field = "id"  # field returned by search_return_id; overwritten by load_jsonl
        self.model = model
        print(f'Initialized FAISS index of type {type(self.index)}')
    def load_jsonl(self, dataset_path: str, embedding_field: str = "embedding", id_field: str = "id") -> None:
        """
        Load a JSONL file and build the FAISS index from precomputed embeddings.
        Each line is expected to be a JSON object such as
        {"id": "...", "embedding": [0.12, -0.03, ...]}.
        :param dataset_path: path to the JSONL file
        :param embedding_field: name of the JSON field holding the embedding
        :param id_field: name of the JSON field used as the retrievable text (the record id)
        """
        print(f'📂 Loading JSONL file: {dataset_path}...')
        self.id_field = id_field
        # Read the JSONL file line by line
        with open(dataset_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                record = json.loads(line)
                self.data.append(record)
        total = len(self.data)
        print(f'✅ Loaded {total} records from {dataset_path}.')
        # Pull the precomputed embedding out of each JSON record
        embeddings_list = []
        for i, rec in enumerate(self.data):
            emb = rec.get(embedding_field, [])
            # Every embedding must match the index dimensionality,
            # otherwise np.stack below would fail with a less helpful error
            if len(emb) != self.index.d:
                raise ValueError(f"Embedding length mismatch at record {i}: expected {self.index.d}, got {len(emb)}.")
            embeddings_list.append(np.array(emb, dtype=np.float32))
        embeddings = np.stack(embeddings_list, axis=0)
        print(f'✅ Embeddings loaded, shape: {embeddings.shape}. Indexing data...')
        # Build the FAISS index over the loaded data
        ids = list(range(total))
        self.index_data(ids, embeddings)
        print('🎉 Indexing complete!')
    def index_data(self, ids: List[int], embeddings: np.ndarray, **kwargs):
        """
        Add precomputed embeddings to the FAISS index.
        :param ids: index number of each record (here list(range(total)))
        :param embeddings: embedding matrix for the records
        """
        self._update_id_mapping(ids)
        embeddings = embeddings.astype('float32')
        # Train the index first if needed (IndexPQ requires training; IndexFlatIP does not)
        if not self.index.is_trained:
            print('⚙️ Training FAISS index...')
            self.index.train(embeddings)
            print('✅ FAISS index trained.')
        self.index.add(embeddings)
        print(f'✅ Indexed {len(self.index_id_to_data)} records.')
    def _update_id_mapping(self, row_ids: List[int]):
        """Extend the mapping from FAISS index IDs to JSON record positions."""
        self.index_id_to_data.extend(row_ids)
    def search_return_id(self, query: str, top_docs: int) -> Tuple[List[str], List[float]]:
        """
        Return the ids and similarity scores of the JSON records most similar to the query.
        :param query: query text
        :param top_docs: number of nearest records to return
        :return: (list of record ids, list of scores)
        """
        db_indices, scores = self.search(query, top_docs)
        # The retrievable text is assumed to live in the configured id field
        result_ids = [self.data[i][self.id_field] for i in db_indices]
        return result_ids, scores
    def search(self, query: str, top_docs: int) -> Tuple[List[int], List[float]]:
        """
        Embed the query and search the FAISS index.
        :param query: query text
        :param top_docs: number of nearest records to return
        :return: (list of JSON record positions, list of similarity scores)
        """
        # Only the query is embedded here; document embeddings are precomputed.
        # Note: inner-product scores equal cosine similarity only when both the
        # stored embeddings and the query embedding are L2-normalized.
        query_vector = self.model.encode(query).astype('float32').reshape(1, -1)
        scores, indexes = self.index.search(query_vector, top_docs)
        scores = scores[0].tolist()
        indexes = indexes[0]
        db_indices = [self.index_id_to_data[i] for i in indexes]
        return db_indices, scores
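

# Hypothetical helper, not part of the original class: a minimal sketch of how
# a compatible JSONL file could be produced, assuming each record's "id" string
# doubles as the text to embed. The field names ("id", "embedding") match the
# defaults expected by JSONLIndexer.load_jsonl above.
def build_embedding_jsonl(texts: List[str], model: SentenceModel, out_path: str) -> None:
    with open(out_path, 'w', encoding='utf-8') as f:
        for text in texts:
            emb = model.encode(text)  # one record at a time; batching would be faster
            record = {"id": text, "embedding": np.asarray(emb, dtype=np.float32).ravel().tolist()}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")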
# Example usage
if __name__ == '__main__':
    model = SentenceModel("BAAI/bge-base-en-v1.5")
    vector_size = 768  # set this to the embedding dimensionality of your model
    indexer = JSONLIndexer(vector_sz=vector_size, model=model)
    jsonl_path = "tool-embedding.jsonl"  # replace with the actual JSONL file path
    indexer.load_jsonl(jsonl_path)
    query = "your search query here"
    ids, scores = indexer.search_return_id(query, top_docs=5)
    print("Search results:", list(zip(ids, scores)))