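# tool_retriever/src/jsonl_Indexer.py
# Last commit by Yyy0530 (9ab9c77):
#   添加 JSONL 文件检索器和 Streamlit 应用,支持基于预计算 embedding 的相似记录检索
#   (Add a JSONL file retriever and a Streamlit app that retrieve similar
#    records using precomputed embeddings.)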
import os
import json
import faiss
import numpy as np
from typing import List, Tuple
from text2vec import SentenceModel
class JSONLIndexer(object):
    """
    JSONL file retriever backed by a FAISS index of precomputed embeddings.

    Each line of a JSONL file is a JSON object that already carries its
    embedding vector. Only the query text is embedded at search time, using
    the ``SentenceModel`` passed to the constructor. Similarity is the inner
    product of the query vector and the stored vectors.
    """

    def __init__(self, vector_sz: int, n_subquantizers=0, n_bits=8, model: SentenceModel = None, **kwargs):
        """
        Create the underlying FAISS index.

        :param vector_sz: dimensionality of the embedding vectors.
        :param n_subquantizers: number of product-quantizer sub-vectors. When > 0
            an ``IndexPQ`` is built, which is trained on the first batch of
            embeddings added. Otherwise an exact ``IndexFlatIP`` is used.
        :param n_bits: bits per sub-vector code (only used when n_subquantizers > 0).
        :param model: SentenceModel used to embed queries in :meth:`search`.
        :param kwargs: accepted for interface compatibility; ignored.
        """
        if n_subquantizers > 0:
            self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
        else:
            self.index = faiss.IndexFlatIP(vector_sz)
        # FAISS row position -> position of the corresponding record in self.data.
        self.index_id_to_data = []
        # Every JSON record loaded so far, in load order.
        self.data = []
        self.model = model
        # Record key returned by search_return_id; set by load_jsonl.
        self.id_field = "id"
        print(f'Initialized FAISS index of type {type(self.index)}')

    def load_jsonl(self, dataset_path: str, embedding_field: str = "embedding", id_field: str = "id") -> None:
        """
        Load a JSONL file and add its precomputed embeddings to the FAISS index.

        Records are appended to ``self.data``. Calling this method several times
        accumulates records from several files, and each record is indexed exactly once.

        :param dataset_path: path to a UTF-8 JSONL file (one JSON object per line;
            blank lines are skipped).
        :param embedding_field: key holding each record's precomputed embedding.
        :param id_field: key whose value :meth:`search_return_id` returns for a hit.
        :raises ValueError: if any record's embedding is missing or its length
            differs from the index dimension. Nothing from the file is added
            in that case.
        """
        print(f'📂 Loading JSONL file: {dataset_path}...')
        new_records = []
        with open(dataset_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    new_records.append(json.loads(line))
        print(f'✅ Loaded {len(new_records)} records from {dataset_path}.')
        self.id_field = id_field

        # np.stack cannot stack zero arrays, so an empty file is a no-op.
        if not new_records:
            print('⚠️ No records found; nothing to index.')
            return

        # Validate and stack before touching self.data, so a bad file leaves
        # the indexer unchanged.
        embeddings = self._stack_embeddings(new_records, embedding_field)
        print(f'✅ Embeddings loaded, shape: {embeddings.shape}. Indexing data...')

        # New records take the positions right after any previously loaded ones.
        # Earlier records must not be re-indexed.
        offset = len(self.data)
        self.data.extend(new_records)
        self.index_data(list(range(offset, offset + len(new_records))), embeddings)
        print('🎉 Indexing complete!')

    def _stack_embeddings(self, records, embedding_field: str) -> np.ndarray:
        """Check every record's embedding length and stack them into an (n, d) float32 matrix."""
        dim = self.index.d
        rows = []
        for pos, rec in enumerate(records):
            # A missing or null field is treated as an empty vector, which fails the length check.
            emb = rec.get(embedding_field) or []
            if len(emb) != dim:
                raise ValueError(
                    f"Embedding length mismatch. Expected {dim}, got {len(emb)} "
                    f"(record #{pos} of this file)."
                )
            rows.append(np.asarray(emb, dtype=np.float32))
        return np.stack(rows, axis=0)

    def index_data(self, ids: List[int], embeddings: np.ndarray, **kwargs):
        """
        Add precomputed embeddings to the FAISS index.

        :param ids: position in ``self.data`` of each embedding row, in row order.
        :param embeddings: (n, d) matrix. It is converted to C-contiguous float32,
            which FAISS requires.
        :raises ValueError: if ``len(ids)`` differs from the number of rows.
        """
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        if len(ids) != embeddings.shape[0]:
            raise ValueError(f"Got {len(ids)} ids for {embeddings.shape[0]} embedding rows.")
        # PQ indexes must be trained before vectors can be added; flat indexes report is_trained=True.
        if not self.index.is_trained:
            print('⚙️ Training FAISS index...')
            self.index.train(embeddings)
            print('✅ FAISS index trained.')
        self.index.add(embeddings)
        # Extend the mapping only after add() succeeded so the two stay aligned.
        self._update_id_mapping(ids)
        print(f'✅ Indexed {len(self.index_id_to_data)} records.')

    def _update_id_mapping(self, row_ids: List[int]):
        """Append FAISS-row -> self.data position entries for newly added rows."""
        self.index_id_to_data.extend(row_ids)

    def search_return_id(self, query: str, top_docs: int) -> Tuple[List[str], np.ndarray]:
        """
        Return the ids and similarity scores of the records most similar to *query*.

        :param query: query text.
        :param top_docs: maximum number of neighbours to return.
        :return: (list of each hit's ``self.id_field`` value, 1-D array of scores),
            best first. There may be fewer than *top_docs* entries when fewer
            records are indexed.
        """
        db_indices, scores = self.search(query, top_docs)
        result_ids = [self.data[i][self.id_field] for i in db_indices]
        return result_ids, scores

    def search(self, query: str, top_docs: int) -> Tuple[List[int], np.ndarray]:
        """
        Embed *query* with ``self.model`` and search the FAISS index.

        :param query: query text.
        :param top_docs: maximum number of neighbours to return.
        :return: (positions in ``self.data``, 1-D float32 array of inner-product
            scores), best first. There may be fewer than *top_docs* entries.
        :raises ValueError: if no model was supplied to the constructor.
        """
        if self.model is None:
            raise ValueError("A SentenceModel is required to embed queries; pass `model` to JSONLIndexer.")
        if top_docs <= 0 or self.index.ntotal == 0:
            return [], np.empty(0, dtype=np.float32)

        query_vector = np.asarray(self.model.encode(query), dtype=np.float32).reshape(1, -1)
        scores, indexes = self.index.search(query_vector, top_docs)
        scores, indexes = scores[0], indexes[0]

        # FAISS pads the result with label -1 when fewer than top_docs vectors
        # exist. Without this filter, index_id_to_data[-1] would silently
        # return the last record.
        valid = indexes >= 0
        db_indices = [self.index_id_to_data[i] for i in indexes[valid]]
        return db_indices, scores[valid]
# Example usage
if __name__ == '__main__':
    encoder = SentenceModel("BAAI/bge-base-en-v1.5")
    # Must equal the model's output size and the length of the precomputed
    # vectors stored in the JSONL file; adjust for your model.
    dim = 768
    retriever = JSONLIndexer(vector_sz=dim, model=encoder)

    # Replace with the path to your own JSONL file.
    retriever.load_jsonl("tool-embedding.jsonl")

    hit_ids, hit_scores = retriever.search_return_id("your search query here", top_docs=5)
    pairs = [(hit_id, score) for hit_id, score in zip(hit_ids, hit_scores)]
    print("检索结果:", pairs)