Upload 6 files

- .gitattributes +1 -0
- README.md +4 -4
- app.py +156 -0
- requirements.txt +152 -0
- src/jsonl_Indexer.py +125 -0
- tool-embedding.jsonl +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tool-embedding.jsonl filter=lfs diff=lfs merge=lfs -text
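Note: the added line is exactly what `git lfs track "tool-embedding.jsonl"` appends, routing the ~675 MB embeddings file (see the LFS pointer at the bottom of this commit) through Git LFS rather than storing it in the repository directly.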
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Tool Retriever
+emoji: 😻
+colorFrom: blue
+colorTo: pink
 sdk: streamlit
 sdk_version: 1.42.2
 app_file: app.py
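(These front-matter fields are the Hugging Face Spaces configuration: `title`, `emoji`, `colorFrom`, and `colorTo` control the Space card, while `sdk`, `sdk_version`, and `app_file` tell Spaces to launch `app.py` with Streamlit 1.42.2.)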
app.py
ADDED
@@ -0,0 +1,156 @@
+import os
+import sys
+import faiss
+import numpy as np
+import streamlit as st
+import pandas as pd
+from text2vec import SentenceModel
+from src.jsonl_Indexer import JSONLIndexer
+
+def get_cli_args():
+    args = {}
+    argv = sys.argv[2:] if len(sys.argv) > 2 else []
+    for arg in argv:
+        if '=' in arg:
+            key, value = arg.split('=', 1)
+            args[key.strip()] = value.strip()
+    return args
+
+cli_args = get_cli_args()
+
+DEFAULT_CONFIG = {
+    'model_path': 'BAAI/bge-base-en-v1.5',
+    'dataset_path': 'tool-embedding.jsonl',
+    'vector_size': 768,
+    'embedding_field': 'embedding',
+    'id_field': 'id'
+}
+
+config = DEFAULT_CONFIG.copy()
+config.update(cli_args)
+config['vector_size'] = int(config['vector_size'])
+
+# ---------------------------
+# Cache the dataset loader (avoids re-downloading the data on every rerun)
+# ---------------------------
+@st.cache_data
+def load_tools_datasets():
+    from datasets import load_dataset, concatenate_datasets
+    ds1 = load_dataset("mangopy/ToolRet-Tools", "code")
+    ds2 = load_dataset("mangopy/ToolRet-Tools", "customized")
+    ds3 = load_dataset("mangopy/ToolRet-Tools", "web")
+    ds = concatenate_datasets([ds1['tools'], ds2['tools'], ds3['tools']])
+    # Rename the 'id' column to 'tool'
+    ds = ds.rename_columns({'id': 'tool'})
+    return ds
+
+ds = load_tools_datasets()
+df2 = ds.to_pandas()
+# For a large dataset, indexing on 'tool' speeds up the join performed later
+df2.set_index('tool', inplace=True)
+
+# ---------------------------
+# Cache the model loader
+# ---------------------------
+@st.cache_resource
+def get_model(model_path: str = config['model_path']):
+    return SentenceModel(model_path)
+
+# Cache the retriever factory
+@st.cache_resource
+def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
+    retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
+    retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
+    return retriever
+
+# ---------------------------
+# Sidebar configuration
+# ---------------------------
+st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)
+model_options = ["BAAI/bge-base-en-v1.5"]
+selected_model = st.sidebar.selectbox("Select Model", model_options)
+st.sidebar.write("Selected model:", selected_model)
+st.sidebar.write("Embedding length: 768")
+
+# Use the model selected in the dropdown (cached, so it is not reloaded on rerun)
+model = get_model(selected_model)
+retriever = create_retriever(
+    config['vector_size'],
+    config['dataset_path'],
+    config['embedding_field'],
+    config['id_field'],
+    _model=model
+)
+
+# ---------------------------
+# Page styling
+# ---------------------------
+st.markdown("""
+<style>
+.search-container {
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    gap: 10px;
+    margin-top: 20px;
+}
+.search-box input {
+    width: 500px !important;
+    height: 45px;
+    font-size: 16px;
+    border-radius: 25px;
+    padding-left: 15px;
+}
+.search-btn button {
+    height: 45px;
+    font-size: 16px;
+    border-radius: 25px;
+}
+</style>
+""", unsafe_allow_html=True)
+
+st.markdown("<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>", unsafe_allow_html=True)
+
+# ---------------------------
+# Main search area
+# ---------------------------
+col1, col2 = st.columns([4, 1])
+with col1:
+    query = st.text_input("", placeholder="Enter your search query...", key="search_query", label_visibility="collapsed")
+with col2:
+    search_clicked = st.button("🔎 Search", use_container_width=True)
+
+top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")
+
+if search_clicked and query:
+    rec_ids, scores = retriever.search_return_id(query, top_k)
+    # Build the results DataFrame
+    df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})
+    # Join against df2, which is already indexed on 'tool', so this is fast
+    results_df = df1.join(df2, on='tool', how='left').reset_index(drop=False)
+
+    st.subheader("🗂️ Retrieval results")
+
+    styled_results = results_df.style.apply(
+        lambda x: [
+            "background-color: #F7F7F7" if i % 2 == 0 else "background-color: #FFFFFF"
+            for i in range(len(x))
+        ],
+        axis=0,
+    ).format({"relevance": "{:.4f}"})
+
+    st.dataframe(
+        styled_results,
+        column_config={
+            "relevance": st.column_config.ProgressColumn(
+                "relevance",
+                help="How closely the record matches the query",
+                format="%.4f",
+                min_value=0,
+                max_value=float(max(scores)) if len(scores) > 0 else 1,
+            ),
+            "tool": st.column_config.TextColumn("tool", help="tool help text", width="medium")
+        },
+        hide_index=True,
+        use_container_width=True,
+    )
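Note: `create_retriever` builds the FAISS index from `tool-embedding.jsonl`, so each line of that file must carry the precomputed vector under the configured `embedding` field and the tool identifier under `id`. A minimal sketch of one record, assuming the 768-dimensional bge-base-en-v1.5 embeddings configured above (the id value is hypothetical, and the `embedding` list holds 768 floats):

{"id": "toolret-web-000123", "embedding": [0.0123, -0.0456, ...]}

Defaults can also be overridden with `key=value` arguments parsed by `get_cli_args` from `sys.argv[2:]`, e.g. `dataset_path=/data/tools.jsonl vector_size=768` (paths here are illustrative).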
requirements.txt
ADDED
@@ -0,0 +1,152 @@
+aiohappyeyeballs==2.4.3
+aiohttp==3.11.0
+aiosignal==1.3.1
+altair==5.4.1
+annotated-types==0.7.0
+anyio==4.6.2.post1
+asttokens==2.4.1
+async-timeout==5.0.1
+attrs==24.2.0
+blessed==1.20.0
+blinker==1.9.0
+branca==0.8.0
+cachetools==5.5.0
+certifi==2024.8.30
+charset-normalizer==3.4.0
+click==8.1.7
+comm==0.2.2
+contourpy==1.3.0
+cycler==0.12.1
+datasets==3.1.0
+debugpy==1.8.8
+decorator==5.1.1
+dill==0.3.8
+distro==1.9.0
+et_xmlfile==2.0.0
+exceptiongroup==1.2.2
+executing==2.1.0
+f==0.0.1
+faiss-gpu==1.7.2
+filelock==3.16.1
+folium==0.18.0
+fonttools==4.54.1
+frozenlist==1.5.0
+fsspec==2024.9.0
+geopandas==1.0.1
+gitdb==4.0.11
+GitPython==3.1.43
+gpustat==1.1.1
+h11==0.14.0
+httpcore==1.0.6
+httpx==0.27.2
+huggingface-hub==0.26.2
+idna==3.10
+importlib_metadata==8.5.0
+importlib_resources==6.4.5
+ipykernel==6.29.5
+ipython==8.18.1
+jedi==0.19.2
+jieba==0.42.1
+Jinja2==3.1.4
+jiter==0.7.1
+joblib==1.4.2
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+jupyter_client==8.6.3
+jupyter_core==5.7.2
+kiwisolver==1.4.7
+loguru==0.7.2
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+matplotlib==3.9.2
+matplotlib-inline==0.1.7
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.1.0
+multiprocess==0.70.16
+narwhals==1.13.5
+nest-asyncio==1.6.0
+networkx==3.2.1
+numpy==1.26.0
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-ml-py==12.560.30
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+openai==0.28.0
+openpyxl==3.1.5
+packaging==24.2
+pandas==2.2.3
+parso==0.8.4
+pexpect==4.9.0
+pillow==11.0.0
+platformdirs==4.3.6
+prettytable==3.12.0
+prompt_toolkit==3.0.48
+propcache==0.2.0
+protobuf==5.28.3
+psutil==6.1.0
+ptyprocess==0.7.0
+pure_eval==0.2.3
+pyarrow==18.0.0
+pydantic==2.9.2
+pydantic_core==2.23.4
+pydeck==0.9.1
+pyecharts==2.0.7
+Pygments==2.18.0
+pyogrio==0.10.0
+pyparsing==3.2.0
+pyproj==3.6.1
+python-dateutil==2.9.0.post0
+pytz==2024.2
+PyYAML==6.0.2
+pyzmq==26.2.0
+referencing==0.35.1
+regex==2024.11.6
+requests==2.32.3
+rich==13.9.4
+rpds-py==0.21.0
+safetensors==0.4.5
+scikit-learn==1.5.2
+scipy==1.13.1
+seaborn==0.13.2
+shapely==2.0.6
+simplejson==3.19.3
+six==1.16.0
+smmap==5.0.1
+sniffio==1.3.1
+stack-data==0.6.3
+streamlit==1.40.1
+streamlit-echarts==0.4.0
+streamlit-option-menu==0.4.0
+streamlit_folium==0.23.1
+sympy==1.13.1
+tenacity==9.0.0
+text2vec==1.3.1
+threadpoolctl==3.5.0
+tokenizers==0.20.3
+toml==0.10.2
+torch==2.5.1
+tornado==6.4.1
+tqdm==4.67.0
+traitlets==5.14.3
+transformers==4.46.2
+triton==3.1.0
+typing_extensions==4.12.2
+tzdata==2024.2
+urllib3==2.2.3
+watchdog==6.0.0
+wcwidth==0.2.13
+wordcloud==1.9.4
+xxhash==3.5.0
+xyzservices==2024.9.0
+yarl==1.17.1
+zipp==3.21.0
src/jsonl_Indexer.py
ADDED
@@ -0,0 +1,125 @@
+import os
+import json
+import faiss
+import numpy as np
+from typing import List, Tuple
+from text2vec import SentenceModel
+
+class JSONLIndexer(object):
+    """
+    Retriever over a JSONL file of precomputed embeddings.
+    """
+    def __init__(self, vector_sz: int, n_subquantizers=0, n_bits=8, model: SentenceModel = None, **kwargs):
+        """
+        Initialize the indexer and choose the FAISS index type.
+        :param vector_sz: dimensionality of the embedding vectors
+        :param n_subquantizers: number of subquantizers
+        :param n_bits: bits per subvector
+        :param model: SentenceModel used to re-embed queries
+        """
+        if n_subquantizers > 0:
+            self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
+        else:
+            self.index = faiss.IndexFlatIP(vector_sz)
+        self.index_id_to_data = []  # maps FAISS index IDs to JSON record indices
+        self.data = []  # holds all JSON objects
+        self.model = model
+
+        print(f'Initialized FAISS index of type {type(self.index)}')
+
+    def load_jsonl(self, dataset_path: str, embedding_field: str = "embedding", id_field: str = "id") -> None:
+        """
+        Load a JSONL file and build the FAISS index from its precomputed embeddings.
+        :param dataset_path: path to the JSONL file
+        :param embedding_field: name of the field holding each record's embedding
+        :param id_field: field used as the retrievable text (assumed here to be the id)
+        """
+        print(f'📂 Loading JSONL file: {dataset_path}...')
+        # Read the JSONL file line by line
+        with open(dataset_path, 'r', encoding='utf-8') as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                record = json.loads(line)
+                self.data.append(record)
+        total = len(self.data)
+        print(f'✅ Loaded {total} records from {dataset_path}.')
+
+        # Extract the precomputed embedding from each JSON object
+        embeddings_list = []
+        for rec in self.data:
+            emb = rec.get(embedding_field, [])
+            # Verify the embedding has the expected length
+            if len(emb) != self.index.d and self.index.ntotal == 0:
+                # On the first mismatch, fail loudly (alternatively, skip the record)
+                raise ValueError(f"Embedding length mismatch. Expected {self.index.d}, got {len(emb)}.")
+            embeddings_list.append(np.array(emb, dtype=np.float32))
+        embeddings = np.stack(embeddings_list, axis=0)
+        print(f'✅ Embeddings loaded, shape: {embeddings.shape}. Indexing data...')
+
+        # Index the data in FAISS
+        ids = list(range(total))
+        self.index_data(ids, embeddings)
+        print('🎉 Indexing complete!')
+
+    def index_data(self, ids: List[int], embeddings: np.array, **kwargs):
+        """
+        Add precomputed embeddings to the FAISS index.
+        :param ids: index number of each record (here list(range(total)))
+        :param embeddings: embedding matrix for the records
+        """
+        self._update_id_mapping(ids)
+        embeddings = embeddings.astype('float32')
+
+        # Train the index first if it is untrained
+        if not self.index.is_trained:
+            print('⚙️ Training FAISS index...')
+            self.index.train(embeddings)
+            print('✅ FAISS index trained.')
+        self.index.add(embeddings)
+        print(f'✅ Indexed {len(self.index_id_to_data)} records.')
+
+    def _update_id_mapping(self, row_ids: List[int]):
+        """Update the mapping from FAISS index IDs to JSON record indices."""
+        self.index_id_to_data.extend(row_ids)
+
+    def search_return_id(self, query: str, top_docs: int) -> Tuple[List[str], List[float]]:
+        """
+        Return the ids and similarity scores of the JSON records most similar to the query.
+        :param query: query text
+        :param top_docs: number of nearest records to return
+        :return: (list of record ids, list of scores)
+        """
+        db_indices, scores = self.search(query, top_docs)
+        # Assumes the retrievable text is the 'id' field of each JSON object
+        result_ids = [self.data[i]["id"] for i in db_indices]
+        return result_ids, scores
+
+    def search(self, query: str, top_docs: int) -> Tuple[List[int], List[float]]:
+        """
+        Re-embed the query and search the FAISS index.
+        :param query: query text
+        :param top_docs: number of nearest records to return
+        :return: (list of JSON record indices, list of similarity scores)
+        """
+        # Only the query is re-embedded; documents use their stored vectors
+        query_vector = self.model.encode(query).astype('float32').reshape(1, -1)
+        scores, indexes = self.index.search(query_vector, top_docs)
+        scores = scores[0]
+        indexes = indexes[0]
+        db_indices = [self.index_id_to_data[i] for i in indexes]
+        return db_indices, scores
+
+# Example usage
+if __name__ == '__main__':
+    model = SentenceModel("BAAI/bge-base-en-v1.5")
+    vector_size = 768  # set to the embedding dimension of your model
+
+    indexer = JSONLIndexer(vector_sz=vector_size, model=model)
+    jsonl_path = "tool-embedding.jsonl"  # replace with the actual JSONL file path
+    indexer.load_jsonl(jsonl_path)
+
+    query = "your search query here"
+    ids, scores = indexer.search_return_id(query, top_docs=5)
+    print("Retrieval results:", list(zip(ids, scores)))
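Since `JSONLIndexer` defaults to `IndexFlatIP` (inner-product search), cosine-style ranking requires the stored vectors to be L2-normalized. A minimal sketch of how `tool-embedding.jsonl` could be produced offline; the input file `tools.jsonl` and its `id`/`text` fields are hypothetical, and the model choice mirrors the app's default:

import json
import numpy as np
from text2vec import SentenceModel

# Same model the app loads by default (768-dim output for bge-base-en-v1.5).
model = SentenceModel("BAAI/bge-base-en-v1.5")

# Hypothetical input: one tool per line, e.g. {"id": "...", "text": "..."}.
with open("tools.jsonl", "r", encoding="utf-8") as fin, \
     open("tool-embedding.jsonl", "w", encoding="utf-8") as fout:
    for line in fin:
        rec = json.loads(line)
        emb = np.asarray(model.encode(rec["text"]), dtype=np.float32)
        # L2-normalize so inner-product scores behave like cosine similarity.
        emb /= (np.linalg.norm(emb) + 1e-12)
        fout.write(json.dumps({"id": rec["id"], "embedding": emb.tolist()}) + "\n")

Note that `search` embeds the query without normalizing it; with normalized document vectors the ranking is unaffected, but the absolute scores equal cosine values only if the query vector is normalized as well.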
tool-embedding.jsonl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e172fd628eba5d4d59d1b467e58827a9874e981a4667e3e297ca4c01f1b275f5
+size 674988141