# Spaces: Running
import os | |
import sys | |
import faiss | |
import numpy as np | |
import streamlit as st | |
import pandas as pd | |
from text2vec import SentenceModel | |
from src.jsonl_Indexer import JSONLIndexer | |
def get_cli_args():
    """Collect ``key=value`` command-line overrides into a dict.

    Only arguments from the third ``sys.argv`` entry onward are inspected
    (the first two entries are skipped). Arguments without an ``=`` are
    ignored; the rest are split on their first ``=`` and both halves are
    stripped of surrounding whitespace. If a key is repeated, the last
    occurrence wins.

    Returns:
        dict: Option name mapped to its raw (string) value.
    """
    # Slicing past the end of the list already yields [], so no length guard.
    pairs = (token.split('=', 1) for token in sys.argv[2:] if '=' in token)
    return {name.strip(): raw.strip() for name, raw in pairs}
# Built-in settings; each one can be overridden by a key=value CLI argument.
DEFAULT_CONFIG = dict(
    model_path='BAAI/bge-base-en-v1.5',
    dataset_path='tool-embedding.jsonl',
    vector_size=768,
    embedding_field='embedding',
    id_field='id',
)
cli_args = get_cli_args()
# CLI values take precedence over the defaults.
config = {**DEFAULT_CONFIG, **cli_args}
# CLI values arrive as strings, so the vector size is normalized to an int.
config['vector_size'] = int(config['vector_size'])
# ---------------------------
# Cached dataset loader (avoids re-loading the data on every script rerun)
# ---------------------------
@st.cache_resource
def load_tools_datasets():
    """Load and merge the three ToolRet-Tools configurations.

    The "code", "customized" and "web" configurations of
    ``mangopy/ToolRet-Tools`` are loaded, their ``tools`` splits are
    concatenated, and the ``id`` column is renamed to ``tool``.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the datasets are loaded
    once per server process and the same object is handed back afterwards.

    Returns:
        The concatenated dataset with its ``id`` column renamed to ``tool``.
    """
    # Imported lazily so the import cost is only paid when the cache is cold.
    from datasets import load_dataset, concatenate_datasets

    subsets = [
        load_dataset("mangopy/ToolRet-Tools", name)['tools']
        for name in ("code", "customized", "web")
    ]
    ds = concatenate_datasets(subsets)
    # Rename the 'id' column to 'tool'.
    ds = ds.rename_columns({'id': 'tool'})
    return ds
ds = load_tools_datasets()
# Index the tool metadata by 'tool' so later merges against retrieval
# results are faster on large data.
df2 = ds.to_pandas().set_index('tool')
# ---------------------------
# Cached model loader
# ---------------------------
@st.cache_resource
def get_model(model_path: str = config['model_path']):
    """Return the sentence-embedding model for ``model_path``.

    Cached with ``st.cache_resource`` so the model is constructed once per
    distinct ``model_path`` instead of on every Streamlit script rerun.

    Args:
        model_path: Model name or path passed to ``SentenceModel``. The
            default is the configured ``model_path`` at definition time.

    Returns:
        The ``SentenceModel`` instance for ``model_path``.
    """
    return SentenceModel(model_path)
# Cached retriever factory
@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    """Build a ``JSONLIndexer`` and load the JSONL embedding file into it.

    Cached with ``st.cache_resource`` so the index is built once per
    distinct combination of the hashable arguments rather than on every
    Streamlit script rerun.

    Args:
        vector_sz: Embedding dimensionality passed to the indexer.
        dataset_path: Path of the JSONL file to load.
        embedding_field: Name of the field holding each embedding.
        id_field: Name of the field holding each record id.
        _model: Embedding model handed to the indexer. The leading
            underscore tells Streamlit not to hash this argument, so it is
            NOT part of the cache key.

    Returns:
        The ``JSONLIndexer`` with ``dataset_path`` loaded.
    """
    retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
    retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
    return retriever
# ---------------------------
# Sidebar configuration
# ---------------------------
st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)
# Take the model from the configuration so a CLI `model_path=...` override is
# honored; with the default configuration this is "BAAI/bge-base-en-v1.5".
model_options = [config['model_path']]
selected_model = st.sidebar.selectbox("Select Model", model_options)
st.sidebar.write("Selected model:", selected_model)
# Report the configured vector size (768 by default) instead of a fixed
# literal, so the sidebar stays truthful when `vector_size` is overridden.
st.sidebar.write(f"Embedding length: {config['vector_size']}")
# Load the model chosen in the select box and build the retriever around it.
model = get_model(selected_model)
retriever = create_retriever(
    config['vector_size'],
    config['dataset_path'],
    config['embedding_field'],
    config['id_field'],
    _model=model
)
# ---------------------------
# Page styling
# ---------------------------
# CSS injected into the page: layout for the search row, the query input and
# the search button.
_PAGE_CSS = """
<style>
.search-container {
display: flex;
justify-content: center;
align-items: center;
gap: 10px;
margin-top: 20px;
}
.search-box input {
width: 500px !important;
height: 45px;
font-size: 16px;
border-radius: 25px;
padding-left: 15px;
}
.search-btn button {
height: 45px;
font-size: 16px;
border-radius: 25px;
}
</style>
"""
# Centered page heading.
_PAGE_TITLE = "<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>"

st.markdown(_PAGE_CSS, unsafe_allow_html=True)
st.markdown(_PAGE_TITLE, unsafe_allow_html=True)
# ---------------------------
# Main search area
# ---------------------------
col1, col2 = st.columns([4, 1])
with col1:
    # The label is hidden via label_visibility, but it is still given real
    # text: Streamlit advises against empty labels for accessibility reasons.
    query = st.text_input(
        "Search query",
        placeholder="Enter your search query...",
        key="search_query",
        label_visibility="collapsed",
    )
with col2:
    search_clicked = st.button("🔎 Search", use_container_width=True)
top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")

if search_clicked and not query:
    # Give feedback instead of silently doing nothing on an empty query.
    st.warning("Please enter a search query.")
elif search_clicked:
    rec_ids, scores = retriever.search_return_id(query, top_k)
    # Build the retrieval-result DataFrame.
    df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})
    # Left-join the tool metadata through df2's index; reset_index keeps the
    # positional index as a regular 'index' column.
    results_df = df1.join(df2, on='tool', how='left').reset_index(drop=False)
    st.subheader("🗂️ Retrieval results")
    # Zebra striping: applied column by column, alternating by row position.
    styled_results = results_df.style.apply(
        lambda column: [
            "background-color: #F7F7F7" if row % 2 == 0 else "background-color: #FFFFFF"
            for row in range(len(column))
        ],
        axis=0,
    ).format({"relevance": "{:.4f}"})
    st.dataframe(
        styled_results,
        column_config={
            "relevance": st.column_config.ProgressColumn(
                "relevance",
                help="记录与查询的匹配程度",
                format="%.4f",
                min_value=0,
                # Scale the bar to the best score; fall back to 1 when there
                # are no results.
                max_value=float(max(scores)) if len(scores) > 0 else 1,
            ),
            "tool": st.column_config.TextColumn("tool", help="tool help text", width="medium")
        },
        hide_index=True,
        use_container_width=True,
    )