mangopy committed
Commit cdfebc6 · verified · 1 parent: 39b5b1c

Upload 6 files

Files changed (6)
  1. .gitattributes +1 -0
  2. README.md +4 -4
  3. app.py +156 -0
  4. requirements.txt +152 -0
  5. src/jsonl_Indexer.py +125 -0
  6. tool-embedding.jsonl +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tool-embedding.jsonl filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
-title: ToolRet Demo
-emoji: 📉
-colorFrom: green
-colorTo: gray
+title: Tool Retriever
+emoji: 😻
+colorFrom: blue
+colorTo: pink
 sdk: streamlit
 sdk_version: 1.42.2
 app_file: app.py
app.py ADDED
@@ -0,0 +1,156 @@
+import os
+import sys
+import faiss
+import numpy as np
+import streamlit as st
+import pandas as pd
+from text2vec import SentenceModel
+from src.jsonl_Indexer import JSONLIndexer
+
+def get_cli_args():
+    args = {}
+    argv = sys.argv[2:] if len(sys.argv) > 2 else []
+    for arg in argv:
+        if '=' in arg:
+            key, value = arg.split('=', 1)
+            args[key.strip()] = value.strip()
+    return args
+
+cli_args = get_cli_args()
+
+DEFAULT_CONFIG = {
+    'model_path': 'BAAI/bge-base-en-v1.5',
+    'dataset_path': 'tool-embedding.jsonl',
+    'vector_size': 768,
+    'embedding_field': 'embedding',
+    'id_field': 'id'
+}
+
+config = DEFAULT_CONFIG.copy()
+config.update(cli_args)
+config['vector_size'] = int(config['vector_size'])
+
+# ---------------------------
+# Cached dataset loader (avoids re-downloading the data on every rerun)
+# ---------------------------
+@st.cache_data
+def load_tools_datasets():
+    from datasets import load_dataset, concatenate_datasets
+    ds1 = load_dataset("mangopy/ToolRet-Tools", "code")
+    ds2 = load_dataset("mangopy/ToolRet-Tools", "customized")
+    ds3 = load_dataset("mangopy/ToolRet-Tools", "web")
+    ds = concatenate_datasets([ds1['tools'], ds2['tools'], ds3['tools']])
+    # Rename the 'id' column to 'tool'
+    ds = ds.rename_columns({'id': 'tool'})
+    return ds
+
+ds = load_tools_datasets()
+df2 = ds.to_pandas()
+# With a large dataset, setting an index speeds up the later join
+df2.set_index('tool', inplace=True)
+
+# ---------------------------
+# Cached model loader
+# ---------------------------
+@st.cache_resource
+def get_model(model_path: str = config['model_path']):
+    return SentenceModel(model_path)
+
+# Cached retriever factory
+@st.cache_resource
+def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
+    retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
+    retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
+    return retriever
+
+# ---------------------------
+# Sidebar configuration
+# ---------------------------
+st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)
+model_options = ["BAAI/bge-base-en-v1.5"]
+selected_model = st.sidebar.selectbox("Select Model", model_options)
+st.sidebar.write("Selected model:", selected_model)
+st.sidebar.write("Embedding length: 768")
+
+# Use the model selected in the dropdown (avoids reloading)
+model = get_model(selected_model)
+retriever = create_retriever(
+    config['vector_size'],
+    config['dataset_path'],
+    config['embedding_field'],
+    config['id_field'],
+    _model=model
+)
+
+# ---------------------------
+# Page styling
+# ---------------------------
+st.markdown("""
+<style>
+.search-container {
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    gap: 10px;
+    margin-top: 20px;
+}
+.search-box input {
+    width: 500px !important;
+    height: 45px;
+    font-size: 16px;
+    border-radius: 25px;
+    padding-left: 15px;
+}
+.search-btn button {
+    height: 45px;
+    font-size: 16px;
+    border-radius: 25px;
+}
+</style>
+""", unsafe_allow_html=True)
+
+st.markdown("<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>", unsafe_allow_html=True)
+
+# ---------------------------
+# Main search area
+# ---------------------------
+col1, col2 = st.columns([4, 1])
+with col1:
+    query = st.text_input("", placeholder="Enter your search query...", key="search_query", label_visibility="collapsed")
+with col2:
+    search_clicked = st.button("🔎 Search", use_container_width=True)
+
+top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")
+
+if search_clicked and query:
+    rec_ids, scores = retriever.search_return_id(query, top_k)
+    # Build the retrieval-results DataFrame
+    df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})
+    # The join is fast because df2 is already indexed on 'tool'
+    results_df = df1.join(df2, on='tool', how='left').reset_index(drop=False)
+
+    st.subheader("🗂️ Retrieval results")
+
+    styled_results = results_df.style.apply(
+        lambda x: [
+            "background-color: #F7F7F7" if i % 2 == 0 else "background-color: #FFFFFF"
+            for i in range(len(x))
+        ],
+        axis=0,
+    ).format({"relevance": "{:.4f}"})
+
+    st.dataframe(
+        styled_results,
+        column_config={
+            "relevance": st.column_config.ProgressColumn(
+                "relevance",
+                help="How well the record matches the query",
+                format="%.4f",
+                min_value=0,
+                max_value=float(max(scores)) if len(scores) > 0 else 1,
+            ),
+            "tool": st.column_config.TextColumn("tool", help="tool help text", width="medium")
+        },
+        hide_index=True,
+        use_container_width=True,
+    )
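A note on the CLI handling above: get_cli_args() starts parsing at sys.argv[2], so the first extra argument passed to the script is skipped. A minimal standalone sketch of that parsing behavior (the argv values here are assumptions for illustration, not part of the commit):

```python
import sys

def get_cli_args():
    # Same parser as app.py: collect key=value pairs from sys.argv[2:]
    args = {}
    argv = sys.argv[2:] if len(sys.argv) > 2 else []
    for arg in argv:
        if '=' in arg:
            key, value = arg.split('=', 1)
            args[key.strip()] = value.strip()
    return args

# Hypothetical argv: the element at index 1 ("skipped") is ignored by design.
sys.argv = ["app.py", "skipped", "vector_size=768", "dataset_path=tool-embedding.jsonl"]
print(get_cli_args())
# -> {'vector_size': '768', 'dataset_path': 'tool-embedding.jsonl'}
```

Parsed values arrive as strings, which is why app.py casts config['vector_size'] back to int after merging the overrides into DEFAULT_CONFIG.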
requirements.txt ADDED
@@ -0,0 +1,152 @@
+aiohappyeyeballs==2.4.3
+aiohttp==3.11.0
+aiosignal==1.3.1
+altair==5.4.1
+annotated-types==0.7.0
+anyio==4.6.2.post1
+asttokens==2.4.1
+async-timeout==5.0.1
+attrs==24.2.0
+blessed==1.20.0
+blinker==1.9.0
+branca==0.8.0
+cachetools==5.5.0
+certifi==2024.8.30
+charset-normalizer==3.4.0
+click==8.1.7
+comm==0.2.2
+contourpy==1.3.0
+cycler==0.12.1
+datasets==3.1.0
+debugpy==1.8.8
+decorator==5.1.1
+dill==0.3.8
+distro==1.9.0
+et_xmlfile==2.0.0
+exceptiongroup==1.2.2
+executing==2.1.0
+f==0.0.1
+faiss-gpu==1.7.2
+filelock==3.16.1
+folium==0.18.0
+fonttools==4.54.1
+frozenlist==1.5.0
+fsspec==2024.9.0
+geopandas==1.0.1
+gitdb==4.0.11
+GitPython==3.1.43
+gpustat==1.1.1
+h11==0.14.0
+httpcore==1.0.6
+httpx==0.27.2
+huggingface-hub==0.26.2
+idna==3.10
+importlib_metadata==8.5.0
+importlib_resources==6.4.5
+ipykernel==6.29.5
+ipython==8.18.1
+jedi==0.19.2
+jieba==0.42.1
+Jinja2==3.1.4
+jiter==0.7.1
+joblib==1.4.2
+jsonschema==4.23.0
+jsonschema-specifications==2024.10.1
+jupyter_client==8.6.3
+jupyter_core==5.7.2
+kiwisolver==1.4.7
+loguru==0.7.2
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+matplotlib==3.9.2
+matplotlib-inline==0.1.7
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.1.0
+multiprocess==0.70.16
+narwhals==1.13.5
+nest-asyncio==1.6.0
+networkx==3.2.1
+numpy==1.26.0
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-ml-py==12.560.30
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+openai==0.28.0
+openpyxl==3.1.5
+packaging==24.2
+pandas==2.2.3
+parso==0.8.4
+pexpect==4.9.0
+pillow==11.0.0
+platformdirs==4.3.6
+prettytable==3.12.0
+prompt_toolkit==3.0.48
+propcache==0.2.0
+protobuf==5.28.3
+psutil==6.1.0
+ptyprocess==0.7.0
+pure_eval==0.2.3
+pyarrow==18.0.0
+pydantic==2.9.2
+pydantic_core==2.23.4
+pydeck==0.9.1
+pyecharts==2.0.7
+Pygments==2.18.0
+pyogrio==0.10.0
+pyparsing==3.2.0
+pyproj==3.6.1
+python-dateutil==2.9.0.post0
+pytz==2024.2
+PyYAML==6.0.2
+pyzmq==26.2.0
+referencing==0.35.1
+regex==2024.11.6
+requests==2.32.3
+rich==13.9.4
+rpds-py==0.21.0
+safetensors==0.4.5
+scikit-learn==1.5.2
+scipy==1.13.1
+seaborn==0.13.2
+shapely==2.0.6
+simplejson==3.19.3
+six==1.16.0
+smmap==5.0.1
+sniffio==1.3.1
+stack-data==0.6.3
+streamlit==1.40.1
+streamlit-echarts==0.4.0
+streamlit-option-menu==0.4.0
+streamlit_folium==0.23.1
+sympy==1.13.1
+tenacity==9.0.0
+text2vec==1.3.1
+threadpoolctl==3.5.0
+tokenizers==0.20.3
+toml==0.10.2
+torch==2.5.1
+tornado==6.4.1
+tqdm==4.67.0
+traitlets==5.14.3
+transformers==4.46.2
+triton==3.1.0
+typing_extensions==4.12.2
+tzdata==2024.2
+urllib3==2.2.3
+watchdog==6.0.0
+wcwidth==0.2.13
+wordcloud==1.9.4
+xxhash==3.5.0
+xyzservices==2024.9.0
+yarl==1.17.1
+zipp==3.21.0
src/jsonl_Indexer.py ADDED
@@ -0,0 +1,125 @@
+import os
+import json
+import faiss
+import numpy as np
+from typing import List, Tuple
+from text2vec import SentenceModel
+
+class JSONLIndexer(object):
+    """
+    Retriever over a JSONL file of precomputed embeddings.
+    """
+    def __init__(self, vector_sz: int, n_subquantizers=0, n_bits=8, model: SentenceModel = None, **kwargs):
+        """
+        Initialize the indexer and choose which FAISS index type to use.
+        :param vector_sz: dimensionality of the embedding vectors
+        :param n_subquantizers: number of subquantizers
+        :param n_bits: number of bits per subvector
+        :param model: SentenceModel used to re-embed incoming queries
+        """
+        if n_subquantizers > 0:
+            self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
+        else:
+            self.index = faiss.IndexFlatIP(vector_sz)
+        self.index_id_to_data = []  # maps FAISS index IDs to JSON record indices
+        self.data = []  # stores all JSON objects
+        self.model = model
+
+        print(f'Initialized FAISS index of type {type(self.index)}')
+
+    def load_jsonl(self, dataset_path: str, embedding_field: str = "embedding", id_field: str = "id") -> None:
+        """
+        Load a JSONL file and build the FAISS index from its precomputed embeddings.
+        :param dataset_path: path to the JSONL file
+        :param embedding_field: name of the field that stores the embedding
+        :param id_field: field treated as the retrievable text (assumed here to be the id)
+        """
+        print(f'📂 Loading JSONL file: {dataset_path}...')
+        # Read the JSONL file line by line
+        with open(dataset_path, 'r', encoding='utf-8') as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                record = json.loads(line)
+                self.data.append(record)
+        total = len(self.data)
+        print(f'✅ Loaded {total} records from {dataset_path}.')
+
+        # Extract the precomputed embedding directly from each JSON object
+        embeddings_list = []
+        for rec in self.data:
+            emb = rec.get(embedding_field, [])
+            # Check that the embedding length matches the index dimensionality
+            if len(emb) != self.index.d and self.index.ntotal == 0:
+                # On a first-add length mismatch, handle as needed (e.g. raise or skip)
+                raise ValueError(f"Embedding length mismatch. Expected {self.index.d}, got {len(emb)}.")
+            embeddings_list.append(np.array(emb, dtype=np.float32))
+        embeddings = np.stack(embeddings_list, axis=0)
+        print(f'✅ Embeddings loaded, shape: {embeddings.shape}. Indexing data...')
+
+        # Build the FAISS index over the data
+        ids = list(range(total))
+        self.index_data(ids, embeddings)
+        print('🎉 Indexing complete!')
+
+    def index_data(self, ids: List[int], embeddings: np.array, **kwargs):
+        """
+        Add precomputed embeddings to the FAISS index.
+        :param ids: index number of each record (here list(range(total)))
+        :param embeddings: matrix of embeddings for the records
+        """
+        self._update_id_mapping(ids)
+        embeddings = embeddings.astype('float32')
+
+        # Train the index first if it is untrained
+        if not self.index.is_trained:
+            print('⚙️ Training FAISS index...')
+            self.index.train(embeddings)
+            print('✅ FAISS index trained.')
+        self.index.add(embeddings)
+        print(f'✅ Indexed {len(self.index_id_to_data)} records.')
+
+    def _update_id_mapping(self, row_ids: List[int]):
+        """Update the mapping from FAISS index IDs to JSON record indices"""
+        self.index_id_to_data.extend(row_ids)
+
+    def search_return_id(self, query: str, top_docs: int) -> Tuple[List[str], List[float]]:
+        """
+        Return the ids and similarity scores of the JSON records most similar to the query.
+        :param query: query text
+        :param top_docs: number of nearest neighbors to return
+        :return: (list of record ids, list of scores)
+        """
+        db_indices, scores = self.search(query, top_docs)
+        # Assumes the retrievable text is the 'id' field of each JSON object
+        result_ids = [self.data[i]["id"] for i in db_indices]
+        return result_ids, scores
+
+    def search(self, query: str, top_docs: int) -> Tuple[List[int], List[float]]:
+        """
+        Re-embed the query, then search the FAISS index.
+        :param query: query text
+        :param top_docs: number of nearest neighbors to return
+        :return: (list of JSON record indices, list of similarity scores)
+        """
+        # Only the query embedding is computed at search time
+        query_vector = self.model.encode(query).astype('float32').reshape(1, -1)
+        scores, indexes = self.index.search(query_vector, top_docs)
+        scores = scores[0]
+        indexes = indexes[0]
+        db_indices = [self.index_id_to_data[i] for i in indexes]
+        return db_indices, scores
+
+# Example usage
+if __name__ == '__main__':
+    model = SentenceModel("BAAI/bge-base-en-v1.5")
+    vector_size = 768  # set this to your model's embedding dimensionality
+
+    indexer = JSONLIndexer(vector_sz=vector_size, model=model)
+    jsonl_path = "tool-embedding.jsonl"  # replace with the actual JSONL file path
+    indexer.load_jsonl(jsonl_path)
+
+    query = "your search query here"
+    ids, scores = indexer.search_return_id(query, top_docs=5)
+    print("Retrieval results:", list(zip(ids, scores)))
tool-embedding.jsonl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e172fd628eba5d4d59d1b467e58827a9874e981a4667e3e297ca4c01f1b275f5
+size 674988141
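What is committed here is only the Git LFS pointer; the actual ~675 MB JSONL payload is fetched at checkout. A quick sanity check after fetching (a sketch, with field names and dimensionality assumed from app.py's defaults) confirms each line matches what JSONLIndexer.load_jsonl and IndexFlatIP(768) expect:

```python
import json

# Inspect the first record of the LFS-resolved file (not the pointer lines above).
with open("tool-embedding.jsonl", encoding="utf-8") as f:
    first = json.loads(next(f))

print(sorted(first.keys()))     # expected to include 'embedding' and 'id'
print(len(first["embedding"]))  # expected: 768, matching config['vector_size']
```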