Yyy0530 commited on
Commit
ec1c57d
·
verified ·
1 Parent(s): b97944f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -36
app.py CHANGED
@@ -30,52 +30,61 @@ config = DEFAULT_CONFIG.copy()
30
  config.update(cli_args)
31
  config['vector_size'] = int(config['vector_size'])
32
 
33
-
34
- #加载数据
35
- from datasets import load_dataset
36
- from datasets import concatenate_datasets
37
- ds1 = load_dataset("mangopy/ToolRet-Tools", "code")
38
- ds2 = load_dataset("mangopy/ToolRet-Tools", "customized")
39
- ds3 = load_dataset("mangopy/ToolRet-Tools", "web")
40
- ds = concatenate_datasets([ds1['tools'], ds2['tools'], ds3['tools']])
41
- ds = ds.rename_columns({'id':'tool'})
42
-
43
- #merge
44
-
45
- # 随便建立一个pd.DataFrame, 有两列,一列是id,一列是text
46
- import pandas as pd
 
47
  df2 = ds.to_pandas()
 
 
48
 
49
-
50
-
 
51
  @st.cache_resource
52
  def get_model(model_path: str = config['model_path']):
53
  return SentenceModel(model_path)
54
 
 
55
  @st.cache_resource
56
  def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
57
  retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
58
  retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
59
  return retriever
60
 
61
- # 在侧边栏中添加模型配置标题
 
 
62
  st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)
63
-
64
-
65
- # 添加模型选项下拉框,目前只有一个模型可选
66
  model_options = ["BAAI/bge-base-en-v1.5"]
67
  selected_model = st.sidebar.selectbox("Select Model", model_options)
68
  st.sidebar.write("Selected model:", selected_model)
69
  st.sidebar.write("Embedding length: 768")
70
 
71
- # 使用选中的模型加载
72
  model = get_model(selected_model)
73
-
74
-
75
- model = get_model(config['model_path'])
76
- retriever = create_retriever(config['vector_size'], config['dataset_path'], config['embedding_field'], config['id_field'], _model=model)
77
-
78
- # 美化界面
 
 
 
 
 
79
  st.markdown("""
80
  <style>
81
  .search-container {
@@ -102,7 +111,9 @@ st.markdown("""
102
 
103
  st.markdown("<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>", unsafe_allow_html=True)
104
 
105
-
 
 
106
  col1, col2 = st.columns([4, 1])
107
  with col1:
108
  query = st.text_input("", placeholder="Enter your search query...", key="search_query", label_visibility="collapsed")
@@ -111,15 +122,13 @@ with col2:
111
 
112
  top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")
113
 
114
- styled_results = None
115
  if search_clicked and query:
116
  rec_ids, scores = retriever.search_return_id(query, top_k)
117
- df1 = pd.DataFrame({ "relevance": scores, "tool": rec_ids})
118
- # print(df1)
119
- # merge两个DataFrame
120
- results_df = pd.merge(df1, df2, on='tool', how = 'left')
121
-
122
- # results_df["interface"] = "asdasdadasdasdasdasdasdasdasasdasdasdasdasdasdasdasdasdasdasdasdasdasdasdasdasdassdasdasdasdasdasabababbabasdbabsdbasbdadabdbasdbasbdbasdbasdbasdb"
123
  st.subheader("🗂️ Retrieval results")
124
 
125
  styled_results = results_df.style.apply(
@@ -129,7 +138,7 @@ if search_clicked and query:
129
  ],
130
  axis=0,
131
  ).format({"relevance": "{:.4f}"})
132
-
133
  st.dataframe(
134
  styled_results,
135
  column_config={
 
30
  config.update(cli_args)
31
  config['vector_size'] = int(config['vector_size'])
32
 
33
+ # ---------------------------
34
+ # 缓存数据集加载函数(避免每次运行时重复下载数据)
35
+ # ---------------------------
36
+ @st.cache_data
37
+ def load_tools_datasets():
38
+ from datasets import load_dataset, concatenate_datasets
39
+ ds1 = load_dataset("mangopy/ToolRet-Tools", "code")
40
+ ds2 = load_dataset("mangopy/ToolRet-Tools", "customized")
41
+ ds3 = load_dataset("mangopy/ToolRet-Tools", "web")
42
+ ds = concatenate_datasets([ds1['tools'], ds2['tools'], ds3['tools']])
43
+ # 重命名'id'字段为'tool'
44
+ ds = ds.rename_columns({'id': 'tool'})
45
+ return ds
46
+
47
+ ds = load_tools_datasets()
48
  df2 = ds.to_pandas()
49
+ # 如果数据量较大,可以通过设置索引加速后续的合并操作
50
+ df2.set_index('tool', inplace=True)
51
 
52
+ # ---------------------------
53
+ # 缓存模型加载函数
54
+ # ---------------------------
55
  @st.cache_resource
56
  def get_model(model_path: str = config['model_path']):
57
  return SentenceModel(model_path)
58
 
59
+ # 缓存检索器创建函数
60
  @st.cache_resource
61
  def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
62
  retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
63
  retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
64
  return retriever
65
 
66
+ # ---------------------------
67
+ # 侧边栏配置
68
+ # ---------------------------
69
  st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)
 
 
 
70
  model_options = ["BAAI/bge-base-en-v1.5"]
71
  selected_model = st.sidebar.selectbox("Select Model", model_options)
72
  st.sidebar.write("Selected model:", selected_model)
73
  st.sidebar.write("Embedding length: 768")
74
 
75
+ # 使用下拉框选中的模型(避免重复加载)
76
  model = get_model(selected_model)
77
+ retriever = create_retriever(
78
+ config['vector_size'],
79
+ config['dataset_path'],
80
+ config['embedding_field'],
81
+ config['id_field'],
82
+ _model=model
83
+ )
84
+
85
+ # ---------------------------
86
+ # 界面样式设置
87
+ # ---------------------------
88
  st.markdown("""
89
  <style>
90
  .search-container {
 
111
 
112
  st.markdown("<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>", unsafe_allow_html=True)
113
 
114
+ # ---------------------------
115
+ # 主体检索区域
116
+ # ---------------------------
117
  col1, col2 = st.columns([4, 1])
118
  with col1:
119
  query = st.text_input("", placeholder="Enter your search query...", key="search_query", label_visibility="collapsed")
 
122
 
123
  top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")
124
 
 
125
  if search_clicked and query:
126
  rec_ids, scores = retriever.search_return_id(query, top_k)
127
+ # 构建检索结果 DataFrame
128
+ df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})
129
+ # 使用 join 加速合并(前提是 df2 已设置好索引)
130
+ results_df = df1.join(df2, on='tool', how='left').reset_index(drop=False)
131
+
 
132
  st.subheader("🗂️ Retrieval results")
133
 
134
  styled_results = results_df.style.apply(
 
138
  ],
139
  axis=0,
140
  ).format({"relevance": "{:.4f}"})
141
+
142
  st.dataframe(
143
  styled_results,
144
  column_config={