Yyy0530 committed on
Commit
9ab9c77
·
1 Parent(s): 5ad0dae

添加 JSONL 文件检索器和 Streamlit 应用,支持基于预计算 embedding 的相似记录检索

Browse files
requirements.txt ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.4.3
2
+ aiohttp==3.11.0
3
+ aiosignal==1.3.1
4
+ altair==5.4.1
5
+ annotated-types==0.7.0
6
+ anyio==4.6.2.post1
7
+ asttokens==2.4.1
8
+ async-timeout==5.0.1
9
+ attrs==24.2.0
10
+ blessed==1.20.0
11
+ blinker==1.9.0
12
+ branca==0.8.0
13
+ cachetools==5.5.0
14
+ certifi==2024.8.30
15
+ charset-normalizer==3.4.0
16
+ click==8.1.7
17
+ comm==0.2.2
18
+ contourpy==1.3.0
19
+ cycler==0.12.1
20
+ datasets==3.1.0
21
+ debugpy==1.8.8
22
+ decorator==5.1.1
23
+ dill==0.3.8
24
+ distro==1.9.0
25
+ et_xmlfile==2.0.0
26
+ exceptiongroup==1.2.2
27
+ executing==2.1.0
28
+ f==0.0.1
29
+ faiss-gpu==1.7.2
30
+ filelock==3.16.1
31
+ folium==0.18.0
32
+ fonttools==4.54.1
33
+ frozenlist==1.5.0
34
+ fsspec==2024.9.0
35
+ geopandas==1.0.1
36
+ gitdb==4.0.11
37
+ GitPython==3.1.43
38
+ gpustat==1.1.1
39
+ h11==0.14.0
40
+ httpcore==1.0.6
41
+ httpx==0.27.2
42
+ huggingface-hub==0.26.2
43
+ idna==3.10
44
+ importlib_metadata==8.5.0
45
+ importlib_resources==6.4.5
46
+ ipykernel==6.29.5
47
+ ipython==8.18.1
48
+ jedi==0.19.2
49
+ jieba==0.42.1
50
+ Jinja2==3.1.4
51
+ jiter==0.7.1
52
+ joblib==1.4.2
53
+ jsonschema==4.23.0
54
+ jsonschema-specifications==2024.10.1
55
+ jupyter_client==8.6.3
56
+ jupyter_core==5.7.2
57
+ kiwisolver==1.4.7
58
+ loguru==0.7.2
59
+ markdown-it-py==3.0.0
60
+ MarkupSafe==3.0.2
61
+ matplotlib==3.9.2
62
+ matplotlib-inline==0.1.7
63
+ mdurl==0.1.2
64
+ mpmath==1.3.0
65
+ multidict==6.1.0
66
+ multiprocess==0.70.16
67
+ narwhals==1.13.5
68
+ nest-asyncio==1.6.0
69
+ networkx==3.2.1
70
+ numpy==1.26.0
71
+ nvidia-cublas-cu12==12.4.5.8
72
+ nvidia-cuda-cupti-cu12==12.4.127
73
+ nvidia-cuda-nvrtc-cu12==12.4.127
74
+ nvidia-cuda-runtime-cu12==12.4.127
75
+ nvidia-cudnn-cu12==9.1.0.70
76
+ nvidia-cufft-cu12==11.2.1.3
77
+ nvidia-curand-cu12==10.3.5.147
78
+ nvidia-cusolver-cu12==11.6.1.9
79
+ nvidia-cusparse-cu12==12.3.1.170
80
+ nvidia-ml-py==12.560.30
81
+ nvidia-nccl-cu12==2.21.5
82
+ nvidia-nvjitlink-cu12==12.4.127
83
+ nvidia-nvtx-cu12==12.4.127
84
+ openai==0.28.0
85
+ openpyxl==3.1.5
86
+ packaging==24.2
87
+ pandas==2.2.3
88
+ parso==0.8.4
89
+ pexpect==4.9.0
90
+ pillow==11.0.0
91
+ platformdirs==4.3.6
92
+ prettytable==3.12.0
93
+ prompt_toolkit==3.0.48
94
+ propcache==0.2.0
95
+ protobuf==5.28.3
96
+ psutil==6.1.0
97
+ ptyprocess==0.7.0
98
+ pure_eval==0.2.3
99
+ pyarrow==18.0.0
100
+ pydantic==2.9.2
101
+ pydantic_core==2.23.4
102
+ pydeck==0.9.1
103
+ pyecharts==2.0.7
104
+ Pygments==2.18.0
105
+ pyogrio==0.10.0
106
+ pyparsing==3.2.0
107
+ pyproj==3.6.1
108
+ python-dateutil==2.9.0.post0
109
+ pytz==2024.2
110
+ PyYAML==6.0.2
111
+ pyzmq==26.2.0
112
+ referencing==0.35.1
113
+ regex==2024.11.6
114
+ requests==2.32.3
115
+ rich==13.9.4
116
+ rpds-py==0.21.0
117
+ safetensors==0.4.5
118
+ scikit-learn==1.5.2
119
+ scipy==1.13.1
120
+ seaborn==0.13.2
121
+ shapely==2.0.6
122
+ simplejson==3.19.3
123
+ six==1.16.0
124
+ smmap==5.0.1
125
+ sniffio==1.3.1
126
+ stack-data==0.6.3
127
+ streamlit==1.40.1
128
+ streamlit-echarts==0.4.0
129
+ streamlit-option-menu==0.4.0
130
+ streamlit_folium==0.23.1
131
+ sympy==1.13.1
132
+ tenacity==9.0.0
133
+ text2vec==1.3.1
134
+ threadpoolctl==3.5.0
135
+ tokenizers==0.20.3
136
+ toml==0.10.2
137
+ torch==2.5.1
138
+ tornado==6.4.1
139
+ tqdm==4.67.0
140
+ traitlets==5.14.3
141
+ transformers==4.46.2
142
+ triton==3.1.0
143
+ typing_extensions==4.12.2
144
+ tzdata==2024.2
145
+ urllib3==2.2.3
146
+ watchdog==6.0.0
147
+ wcwidth==0.2.13
148
+ wordcloud==1.9.4
149
+ xxhash==3.5.0
150
+ xyzservices==2024.9.0
151
+ yarl==1.17.1
152
+ zipp==3.21.0
src/jsonl_Indexer.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import faiss
4
+ import numpy as np
5
+ from typing import List, Tuple
6
+ from text2vec import SentenceModel
7
+
8
class JSONLIndexer(object):
    """Similarity search over JSONL records that carry precomputed embeddings.

    Records are read from one or more JSONL files, their stored embedding
    vectors are added to a FAISS inner-product index, and text queries are
    embedded on the fly with ``model`` before being searched.
    """

    def __init__(self, vector_sz: int, n_subquantizers=0, n_bits=8, model: "SentenceModel" = None, **kwargs):
        """Create the FAISS index.

        :param vector_sz: dimensionality of the embedding vectors
        :param n_subquantizers: number of PQ sub-quantizers; ``0`` selects an
            exact ``IndexFlatIP`` instead of ``IndexPQ``
        :param n_bits: bits per sub-vector (only used by ``IndexPQ``)
        :param model: object exposing ``encode(text)``; used to embed queries
        :param kwargs: accepted for call-site compatibility and ignored
        """
        if n_subquantizers > 0:
            self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
        else:
            self.index = faiss.IndexFlatIP(vector_sz)
        self.index_id_to_data = []  # FAISS row number -> position in self.data
        self.data = []  # every loaded JSON object, in load order
        self.model = model
        self.id_field = "id"  # field returned by search_return_id; set by load_jsonl

        print(f'Initialized FAISS index of type {type(self.index)}')

    def load_jsonl(self, dataset_path: str, embedding_field: str = "embedding", id_field: str = "id") -> None:
        """Read a JSONL file and add its precomputed embeddings to the index.

        May be called more than once; each call only indexes the records of
        the file it was given. Nothing is committed to ``self.data`` or the
        index unless the whole file parses and validates.

        :param dataset_path: path of the JSONL file (one JSON object per line)
        :param embedding_field: key under which each record stores its embedding
        :param id_field: key whose value ``search_return_id`` reports for a hit
        :raises ValueError: if the file holds no records or an embedding does
            not have the index dimensionality
        """
        print(f'📂 Loading JSONL file: {dataset_path}...')
        new_records = []
        with open(dataset_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines
                new_records.append(json.loads(line))
        print(f'✅ Loaded {len(new_records)} records from {dataset_path}.')

        if not new_records:
            raise ValueError(f"No records found in {dataset_path}.")

        # Validate every record (not only the first load) before touching the index.
        dim = self.index.d
        embeddings = np.empty((len(new_records), dim), dtype=np.float32)
        for row, rec in enumerate(new_records):
            emb = rec.get(embedding_field) or []
            if len(emb) != dim:
                raise ValueError(f"Embedding length mismatch. Expected {dim}, got {len(emb)}.")
            embeddings[row] = emb
        print(f'✅ Embeddings loaded, shape: {embeddings.shape}. Indexing data...')

        # Positions the new records will occupy in self.data once appended.
        first = len(self.data)
        ids = list(range(first, first + len(new_records)))
        self.index_data(ids, embeddings)
        self.data.extend(new_records)
        self.id_field = id_field
        print('🎉 Indexing complete!')

    def index_data(self, ids: List[int], embeddings: np.ndarray, **kwargs):
        """Add precomputed embeddings to the FAISS index.

        :param ids: position in ``self.data`` of the record behind each row
        :param embeddings: 2-D matrix with one embedding per row
        :param kwargs: accepted for call-site compatibility and ignored
        :raises ValueError: if ``ids`` and ``embeddings`` disagree in length
        """
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        if embeddings.ndim != 2 or len(ids) != embeddings.shape[0]:
            raise ValueError(
                f"Expected one embedding row per id, got {len(ids)} ids "
                f"and embeddings of shape {embeddings.shape}."
            )

        # Quantizing indexes (e.g. IndexPQ) must be trained before vectors are added.
        if not self.index.is_trained:
            print('⚙️ Training FAISS index...')
            self.index.train(embeddings)
            print('✅ FAISS index trained.')
        self.index.add(embeddings)
        # Record the mapping only after the add succeeded so both stay in sync.
        self._update_id_mapping(ids)
        print(f'✅ Indexed {len(self.index_id_to_data)} records.')

    def _update_id_mapping(self, row_ids: List[int]):
        """Append the record positions for newly added FAISS rows."""
        self.index_id_to_data.extend(row_ids)

    def search_return_id(self, query: str, top_docs: int) -> Tuple[List[str], np.ndarray]:
        """Return the id-field values and scores of the records closest to ``query``.

        :param query: query text
        :param top_docs: maximum number of neighbours to return
        :return: ``(ids, scores)``; ids come from the ``id_field`` given to
            the most recent ``load_jsonl`` call (``"id"`` by default)
        """
        db_indices, scores = self.search(query, top_docs)
        result_ids = [self.data[i][self.id_field] for i in db_indices]
        return result_ids, scores

    def search(self, query: str, top_docs: int) -> Tuple[List[int], np.ndarray]:
        """Embed ``query`` and look up its nearest neighbours in the index.

        :param query: query text
        :param top_docs: maximum number of neighbours to return
        :return: ``(positions in self.data, similarity scores)``; scores are
            a 1-D float32 array aligned with the positions
        :raises ValueError: if no embedding model was supplied
        """
        if self.model is None:
            raise ValueError("A model is required to embed the query; pass model= to JSONLIndexer.")
        query_vector = np.asarray(self.model.encode(query)).astype('float32').reshape(1, -1)
        scores, indexes = self.index.search(query_vector, top_docs)
        scores = scores[0]
        indexes = indexes[0]
        # FAISS pads with label -1 when fewer than top_docs vectors exist;
        # without this filter -1 would index the *last* record by accident.
        found = indexes >= 0
        db_indices = [self.index_id_to_data[i] for i in indexes[found]]
        return db_indices, scores[found]
113
+
114
# Example usage: build an index from a JSONL file and run a single query.
if __name__ == '__main__':
    vector_size = 768  # must match the embedding dimension of the model
    jsonl_path = "tool-embedding.jsonl"  # replace with the real JSONL path
    query = "your search query here"

    model = SentenceModel("BAAI/bge-base-en-v1.5")
    indexer = JSONLIndexer(vector_sz=vector_size, model=model)
    indexer.load_jsonl(jsonl_path)

    ids, scores = indexer.search_return_id(query, top_docs=5)
    print("检索结果:", list(zip(ids, scores)))
streamlit_jsonl_retriever.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['CUDA_VISIBLE_DEVICES'] = '3'
3
+
4
+ import os
5
+ import sys
6
+ import faiss
7
+ import numpy as np
8
+ import streamlit as st
9
+ from text2vec import SentenceModel
10
+ # 请确保 JSONLIndexer 在 src 目录下或者已正确安装
11
+ from src.jsonl_Indexer import JSONLIndexer
12
+
13
# Command-line argument parsing
def get_cli_args():
    """Collect ``key=value`` overrides from the command line.

    Every argument after the script name is inspected; arguments without an
    ``=`` are ignored. Only the first ``=`` separates key from value, and
    both sides are stripped of surrounding whitespace.

    :return: dict mapping each key to its (string) value
    """
    args = {}
    # Streamlit rewrites sys.argv to [script_path, *script_args], so only
    # element 0 must be skipped; skipping two would drop the first override.
    for arg in sys.argv[1:]:
        key, sep, value = arg.partition('=')
        if sep:
            args[key.strip()] = value.strip()
    return args
23
+
24
# key=value overrides supplied on the command line.
cli_args = get_cli_args()

# Fallback settings, suited to the bundled JSONL dataset.
DEFAULT_CONFIG = {
    'model_path': 'BAAI/bge-base-en-v1.5',
    'dataset_path': 'src/tool-embedding.jsonl',  # path of the JSONL file
    'vector_size': 768,
    'embedding_field': 'embedding',  # JSON key that stores the embedding
    'id_field': 'id',  # JSON key reported for each retrieved record
}

# Command-line values take precedence over the defaults.
config = {**DEFAULT_CONFIG, **cli_args}

# Overrides arrive as strings, so coerce vector_size to an integer.
config['vector_size'] = int(config['vector_size'])
42
+
43
@st.cache_resource
def get_model(model_path: str = config['model_path']):
    """Load the sentence-embedding model found at ``model_path``.

    Wrapped in ``st.cache_resource``; the default path is taken from the
    merged configuration at definition time.
    """
    return SentenceModel(model_path)
47
+
48
@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    """Build a ``JSONLIndexer`` and fill it from ``dataset_path``.

    :param vector_sz: embedding dimensionality handed to the indexer
    :param dataset_path: JSONL file to load
    :param embedding_field: JSON key that stores each record's embedding
    :param id_field: JSON key forwarded to ``load_jsonl`` as ``id_field``
    :param _model: query-embedding model; presumably underscore-prefixed so
        that ``st.cache_resource`` does not try to hash it
    :return: the populated indexer
    """
    indexer = JSONLIndexer(vector_sz=vector_sz, model=_model)
    indexer.load_jsonl(
        dataset_path,
        embedding_field=embedding_field,
        id_field=id_field,
    )
    return indexer
53
+
54
# Optionally list the active configuration in the sidebar.
show_config = st.sidebar.checkbox("Show Configuration")
if show_config:
    st.sidebar.write("Current Configuration:")
    for name in config:
        st.sidebar.write(f"{name}: {config[name]}")

# Set up the embedding model and the retriever built on top of it.
model = get_model(config['model_path'])
retriever = create_retriever(
    config['vector_size'],
    config['dataset_path'],
    config['embedding_field'],
    config['id_field'],
    _model=model,
)
69
+
70
import html  # stdlib; escapes dataset-provided ids before they are embedded in HTML

# Streamlit page header
st.title("JSONL Data Retrieval Visualization")
st.write("该应用基于预计算的 JSONL 文件 embedding,输入查询后将检索相似记录。")

# Query input
query = st.text_input("Enter a search query:")
top_k = st.slider("Select number of results to display", min_value=1, max_value=100, value=5)

# Run the search and render one card per hit
if st.button("Search") and query:
    # search_return_id yields the id-field value of each hit plus its score.
    rec_ids, scores = retriever.search_return_id(query, top_k)

    st.write("### Results:")

    with st.expander("Retrieval Results (click to expand)"):
        for j, rec_id in enumerate(rec_ids):
            # The id comes from the dataset file and is rendered with
            # unsafe_allow_html=True, so escape it to prevent markup injection.
            safe_id = html.escape(str(rec_id))
            st.markdown(
                f"""
                <div style="border:1px solid #ccc; padding:10px; border-radius:5px; margin-bottom:10px; background-color:#f9f9f9;">
                <p><b>Record {j+1} ID:</b> {safe_id}</p>
                <p><b>Score:</b> {scores[j]:.4f}</p>
                </div>
                """,
                unsafe_allow_html=True
            )