"""Streamlit app for tool retrieval.

Embeds a user query with a sentence-embedding model, searches a FAISS-backed
JSONL index of tool embeddings, and displays the matching tools joined with
metadata from the ToolRet-Tools dataset.

NOTE(review): this source arrived with its line structure collapsed and the
HTML/CSS inside the markdown string literals stripped by an extraction step.
The code below restores conventional formatting; only the visible text of
those strings survives — restore the original markup if it is available.
"""

import os
import sys

import faiss
import numpy as np
import pandas as pd
import streamlit as st
from text2vec import SentenceModel

from src.jsonl_Indexer import JSONLIndexer


def get_cli_args():
    """Parse ``key=value`` pairs from the command line into a dict.

    Returns:
        dict[str, str]: mapping of stripped keys to stripped values; pairs
        without ``=`` are ignored.

    NOTE(review): slicing from index 2 skips the first positional argument —
    presumably to match how this app is launched (e.g. ``streamlit run
    app.py ...``); confirm against the actual launch command.
    """
    args = {}
    argv = sys.argv[2:] if len(sys.argv) > 2 else []
    for arg in argv:
        if '=' in arg:
            key, value = arg.split('=', 1)
            args[key.strip()] = value.strip()
    return args


cli_args = get_cli_args()

# Defaults; any key may be overridden via key=value CLI arguments.
DEFAULT_CONFIG = {
    'model_path': 'BAAI/bge-base-en-v1.5',
    'dataset_path': 'tool-embedding.jsonl',
    'vector_size': 768,
    'embedding_field': 'embedding',
    'id_field': 'id'
}
config = DEFAULT_CONFIG.copy()
config.update(cli_args)
# CLI overrides arrive as strings, so coerce the vector size back to int.
config['vector_size'] = int(config['vector_size'])


# ---------------------------
# Cached dataset loader (avoids re-downloading the data on every rerun)
# ---------------------------
@st.cache_data
def load_tools_datasets():
    """Load and concatenate the three ToolRet-Tools subsets.

    Returns a ``datasets.Dataset`` whose ``id`` column has been renamed to
    ``tool`` so it can be joined with retrieval results.
    """
    from datasets import load_dataset, concatenate_datasets
    ds1 = load_dataset("mangopy/ToolRet-Tools", "code")
    ds2 = load_dataset("mangopy/ToolRet-Tools", "customized")
    ds3 = load_dataset("mangopy/ToolRet-Tools", "web")
    ds = concatenate_datasets([ds1['tools'], ds2['tools'], ds3['tools']])
    # Rename the 'id' field to 'tool'
    ds = ds.rename_columns({'id': 'tool'})
    return ds


ds = load_tools_datasets()
df2 = ds.to_pandas()
# Index on 'tool' up front so the per-search join below is fast.
df2.set_index('tool', inplace=True)


# ---------------------------
# Cached model loader
# ---------------------------
@st.cache_resource
def get_model(model_path: str = config['model_path']):
    """Load (once) and return the sentence-embedding model."""
    return SentenceModel(model_path)


# Cached retriever factory. The leading underscore on ``_model`` tells
# Streamlit not to hash the model object when computing the cache key.
@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str,
                     embedding_field: str, id_field: str, _model):
    """Build a JSONLIndexer over ``dataset_path`` and return it."""
    retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
    retriever.load_jsonl(dataset_path,
                         embedding_field=embedding_field,
                         id_field=id_field)
    return retriever


# ---------------------------
# Sidebar configuration
# ---------------------------
st.sidebar.markdown("""

📄 Model Configuration

""", unsafe_allow_html=True)

model_options = ["BAAI/bge-base-en-v1.5"]
selected_model = st.sidebar.selectbox("Select Model", model_options)
st.sidebar.write("Selected model:", selected_model)
st.sidebar.write("Embedding length: 768")

# Use the model chosen in the sidebar (cached, so no duplicate loads).
model = get_model(selected_model)
retriever = create_retriever(
    config['vector_size'],
    config['dataset_path'],
    config['embedding_field'],
    config['id_field'],
    _model=model
)

# ---------------------------
# Page styling
# ---------------------------
st.markdown(""" """, unsafe_allow_html=True)
st.markdown("""

🔍 Tool Retrieval

""", unsafe_allow_html=True)

# ---------------------------
# Main search area
# ---------------------------
col1, col2 = st.columns([4, 1])
with col1:
    query = st.text_input("", placeholder="Enter your search query...",
                          key="search_query", label_visibility="collapsed")
with col2:
    search_clicked = st.button("🔎 Search", use_container_width=True)

top_k = st.slider("Top-K tools", 1, 100, 50,
                  help="Choose the number of results to display")

if search_clicked and query:
    rec_ids, scores = retriever.search_return_id(query, top_k)
    # Build the retrieval-result DataFrame
    df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})
    # Fast merge via join (df2 is already indexed on 'tool').
    results_df = df1.join(df2, on='tool', how='left').reset_index(drop=False)

    st.subheader("🗂️ Retrieval results")
    # Zebra-stripe alternating rows and format the relevance score.
    styled_results = results_df.style.apply(
        lambda x: [
            "background-color: #F7F7F7" if i % 2 == 0
            else "background-color: #FFFFFF"
            for i in range(len(x))
        ],
        axis=0,
    ).format({"relevance": "{:.4f}"})

    st.dataframe(
        styled_results,
        column_config={
            "relevance": st.column_config.ProgressColumn(
                "relevance",
                help="记录与查询的匹配程度",
                format="%.4f",
                min_value=0,
                # Scale the progress bar to the best score of this search.
                max_value=float(max(scores)) if len(scores) > 0 else 1,
            ),
            "tool": st.column_config.TextColumn("tool",
                                                help="tool help text",
                                                width="medium")
        },
        hide_index=True,
        use_container_width=True,
    )