# tool_retriever / app.py
# Last commit ec1c57d by Yyy0530: "Update app.py" (verified)
import os
import sys
import faiss
import numpy as np
import streamlit as st
import pandas as pd
from text2vec import SentenceModel
from src.jsonl_Indexer import JSONLIndexer
def get_cli_args(argv=None):
    """Parse ``key=value`` script arguments into a dict of string overrides.

    When launched as ``streamlit run app.py -- key=value ...`` the script's own
    arguments follow the script path in ``sys.argv``, i.e. they start at index 1.
    (Slicing from index 2 silently dropped the first override.)

    Args:
        argv: Argument strings to parse. Defaults to ``sys.argv[1:]``.

    Returns:
        Dict mapping stripped keys to stripped string values. Arguments without
        ``=`` and arguments with an empty key are ignored. Only the first ``=``
        splits, so values may themselves contain ``=``. A later duplicate key
        overrides an earlier one.
    """
    if argv is None:
        argv = sys.argv[1:]
    args = {}
    for arg in argv:
        key, sep, value = arg.partition('=')
        key = key.strip()
        # Skip positional arguments (no '=') and malformed '=value' entries.
        if sep and key:
            args[key] = value.strip()
    return args
# Command-line overrides (string-valued ``key=value`` pairs).
cli_args = get_cli_args()

# Built-in settings used when no override is supplied.
DEFAULT_CONFIG = {
    'model_path': 'BAAI/bge-base-en-v1.5',
    'dataset_path': 'tool-embedding.jsonl',
    'vector_size': 768,
    'embedding_field': 'embedding',
    'id_field': 'id',
}

# Later mappings win, so CLI values take precedence over the defaults.
config = {**DEFAULT_CONFIG, **cli_args}
# A CLI override arrives as a string; coerce to int (raises ValueError if not numeric).
config['vector_size'] = int(config['vector_size'])
# ---------------------------
# Cached dataset loading (avoids re-downloading the data on every rerun)
# ---------------------------
@st.cache_data
def load_tools_datasets():
    """Load the ToolRet tool catalogue as a single dataset.

    Fetches the ``code``, ``customized`` and ``web`` configurations of
    ``mangopy/ToolRet-Tools``, concatenates their ``tools`` splits in that
    order, and renames the ``id`` column to ``tool``. ``st.cache_data`` keeps the
    result across Streamlit reruns.
    """
    from datasets import load_dataset, concatenate_datasets

    subsets = ("code", "customized", "web")
    tool_splits = [load_dataset("mangopy/ToolRet-Tools", name)['tools'] for name in subsets]
    combined = concatenate_datasets(tool_splits)
    # Name the identifier column 'tool' so it lines up with the retrieval results.
    return combined.rename_columns({'id': 'tool'})
ds = load_tools_datasets()


@st.cache_resource
def _load_tools_frame():
    """Return the tool catalogue as a DataFrame indexed by ``tool``, built once.

    Converting the dataset to pandas and indexing it is identical work on every
    Streamlit rerun (i.e. every widget interaction), so it is cached instead of
    being repeated. ``st.cache_resource`` hands back the same object without a
    per-call copy. In this file ``df2`` is only read (joined against), so sharing
    one instance is safe.
    """
    frame = load_tools_datasets().to_pandas()
    # Index on the tool id so retrieval results can be joined on it directly.
    frame.set_index('tool', inplace=True)
    return frame


df2 = _load_tools_frame()
# ---------------------------
# Cached model loading
# ---------------------------
@st.cache_resource
def get_model(model_path: str = config['model_path']):
    """Construct the text2vec ``SentenceModel`` for ``model_path``.

    ``st.cache_resource`` keys on ``model_path``, so each distinct path is loaded
    once and reused across reruns. The default is taken from ``config`` when this
    function is defined.
    """
    sentence_model = SentenceModel(model_path)
    return sentence_model
# Cached retriever construction
@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    """Build a ``JSONLIndexer`` and load the precomputed embeddings into it.

    Args:
        vector_sz: Embedding dimension passed to the indexer.
        dataset_path: Path of the JSONL file handed to ``load_jsonl``.
        embedding_field: Field name read as the embedding by ``load_jsonl``.
        id_field: Field name read as the record id by ``load_jsonl``.
        _model: Encoder passed to the indexer as ``model``.

    Returns:
        The populated ``JSONLIndexer``.

    NOTE(review): Streamlit does not hash parameters whose names start with an
    underscore, so ``_model`` is not part of the cache key. A call with a
    different model but identical other arguments would get back the retriever
    built with the first model.
    """
    indexer = JSONLIndexer(vector_sz=vector_sz, model=_model)
    indexer.load_jsonl(
        dataset_path,
        embedding_field=embedding_field,
        id_field=id_field,
    )
    return indexer
# ---------------------------
# Sidebar configuration
# ---------------------------
st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)
# Offer the configured model (the CLI `model_path=` override, or the default)
# rather than a hard-coded name, so the CLI setting actually takes effect.
model_options = [config['model_path']]
selected_model = st.sidebar.selectbox("Select Model", model_options)
st.sidebar.write("Selected model:", selected_model)
# Report the configured embedding dimension instead of a hard-coded 768.
st.sidebar.write(f"Embedding length: {config['vector_size']}")
# Load the model chosen in the selectbox; get_model is cached, so reruns reuse it.
model = get_model(selected_model)
retriever = create_retriever(
    config['vector_size'],
    config['dataset_path'],
    config['embedding_field'],
    config['id_field'],
    _model=model,
)
# ---------------------------
# Page styling and title
# ---------------------------
# NOTE(review): no element created in this file is rendered with the
# .search-container / .search-box / .search-btn classes, so these rules may not
# match anything on the page. Confirm before relying on them.
_SEARCH_CSS = """
<style>
.search-container {
display: flex;
justify-content: center;
align-items: center;
gap: 10px;
margin-top: 20px;
}
.search-box input {
width: 500px !important;
height: 45px;
font-size: 16px;
border-radius: 25px;
padding-left: 15px;
}
.search-btn button {
height: 45px;
font-size: 16px;
border-radius: 25px;
}
</style>
"""
_TITLE_HTML = "<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>"

st.markdown(_SEARCH_CSS, unsafe_allow_html=True)
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
# ---------------------------
# Main search area
# ---------------------------
col1, col2 = st.columns([4, 1])
with col1:
    # A non-empty label avoids Streamlit's empty-label accessibility warning.
    # label_visibility="collapsed" keeps it hidden.
    query = st.text_input(
        "Search query",
        placeholder="Enter your search query...",
        key="search_query",
        label_visibility="collapsed",
    )
with col2:
    search_clicked = st.button("🔎 Search", use_container_width=True)
top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")

# Treat whitespace-only input as "no query" instead of sending it to the retriever.
query_text = (query or "").strip()
if search_clicked and query_text:
    rec_ids, scores = retriever.search_return_id(query_text, top_k)
    # Results frame: one row per retrieved id, paired with its score.
    df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})
    # df2 is indexed by 'tool', so join on that column. how='left' keeps every
    # retrieved row even if its id is missing from the catalogue.
    # reset_index(drop=False) keeps the row index as a visible 'index' column.
    results_df = df1.join(df2, on='tool', how='left').reset_index(drop=False)
    st.subheader("🗂️ Retrieval results")
    # Alternate row shading: each column gets a style per row position.
    styled_results = results_df.style.apply(
        lambda column: [
            "background-color: #F7F7F7" if row % 2 == 0 else "background-color: #FFFFFF"
            for row in range(len(column))
        ],
        axis=0,
    ).format({"relevance": "{:.4f}"})
    # The progress bar runs from 0 to max_value. Scale to the best score when it
    # is positive. Otherwise (no results, or all scores <= 0) fall back to 1 so
    # max_value never drops to or below min_value.
    top_score = float(max(scores)) if len(scores) > 0 else 0.0
    progress_max = top_score if top_score > 0 else 1.0
    st.dataframe(
        styled_results,
        column_config={
            "relevance": st.column_config.ProgressColumn(
                "relevance",
                help="记录与查询的匹配程度",
                format="%.4f",
                min_value=0,
                max_value=progress_max,
            ),
            "tool": st.column_config.TextColumn("tool", help="tool help text", width="medium"),
        },
        hide_index=True,
        use_container_width=True,
    )