import os
import sys
import faiss
import numpy as np
import streamlit as st
import pandas as pd
from text2vec import SentenceModel
from src.jsonl_Indexer import JSONLIndexer
def get_cli_args():
    """Parse key=value arguments from the command line into a dict."""
    args = {}
    argv = sys.argv[2:] if len(sys.argv) > 2 else []
    for arg in argv:
        if '=' in arg:
            key, value = arg.split('=', 1)
            args[key.strip()] = value.strip()
    return args
cli_args = get_cli_args()
DEFAULT_CONFIG = {
    'model_path': 'BAAI/bge-base-en-v1.5',
    'dataset_path': 'tool-embedding.jsonl',
    'vector_size': 768,
    'embedding_field': 'embedding',
    'id_field': 'id'
}
config = DEFAULT_CONFIG.copy()
config.update(cli_args)
config['vector_size'] = int(config['vector_size'])
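# Any DEFAULT_CONFIG key can be overridden with a key=value command-line argument,
# e.g. dataset_path=other-embedding.jsonl vector_size=1024 (values arrive as strings,
# hence the int() cast above). This example invocation is illustrative, not from the repo.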
# ---------------------------
# Cache the dataset loading function (avoids re-downloading the data on every rerun)
# ---------------------------
@st.cache_data
def load_tools_datasets():
    from datasets import load_dataset, concatenate_datasets
    ds1 = load_dataset("mangopy/ToolRet-Tools", "code")
    ds2 = load_dataset("mangopy/ToolRet-Tools", "customized")
    ds3 = load_dataset("mangopy/ToolRet-Tools", "web")
    ds = concatenate_datasets([ds1['tools'], ds2['tools'], ds3['tools']])
    # Rename the 'id' column to 'tool'
    ds = ds.rename_columns({'id': 'tool'})
    return ds
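# Note: @st.cache_data memoizes the concatenated dataset across Streamlit reruns,
# so the three ToolRet-Tools configs are downloaded once rather than on every rerun.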
ds = load_tools_datasets()
df2 = ds.to_pandas()
# For larger datasets, setting an index speeds up the later join
df2.set_index('tool', inplace=True)
# ---------------------------
# Cache the model loading function
# ---------------------------
@st.cache_resource
def get_model(model_path: str = config['model_path']):
    return SentenceModel(model_path)

# Cache the retriever construction function
@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    retriever = JSONLIndexer(vector_sz=vector_sz, model=_model)
    retriever.load_jsonl(dataset_path, embedding_field=embedding_field, id_field=id_field)
    return retriever
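# Note: the leading underscore in `_model` tells st.cache_resource not to hash that
# argument when computing the cache key; unhashable or heavy objects are
# conventionally passed with a `_` prefix for this reason.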
# ---------------------------
# Sidebar configuration
# ---------------------------
st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)
model_options = ["BAAI/bge-base-en-v1.5"]
selected_model = st.sidebar.selectbox("Select Model", model_options)
st.sidebar.write("Selected model:", selected_model)
st.sidebar.write("Embedding length: 768")
# Use the model selected in the dropdown (cached, so it is not reloaded on every rerun)
model = get_model(selected_model)
retriever = create_retriever(
    config['vector_size'],
    config['dataset_path'],
    config['embedding_field'],
    config['id_field'],
    _model=model
)
# ---------------------------
# Page styling
# ---------------------------
st.markdown("""
<style>
.search-container {
    display: flex;
    justify-content: center;
    align-items: center;
    gap: 10px;
    margin-top: 20px;
}
.search-box input {
    width: 500px !important;
    height: 45px;
    font-size: 16px;
    border-radius: 25px;
    padding-left: 15px;
}
.search-btn button {
    height: 45px;
    font-size: 16px;
    border-radius: 25px;
}
</style>
""", unsafe_allow_html=True)
st.markdown("<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>", unsafe_allow_html=True)
# ---------------------------
# Main search area
# ---------------------------
col1, col2 = st.columns([4, 1])
with col1:
    query = st.text_input("", placeholder="Enter your search query...", key="search_query", label_visibility="collapsed")
with col2:
    search_clicked = st.button("🔎 Search", use_container_width=True)
top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")
if search_clicked and query:
    rec_ids, scores = retriever.search_return_id(query, top_k)
    # Build a DataFrame of the retrieval results
    df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})
    # Use join for a faster merge (requires df2 to already be indexed by 'tool')
    results_df = df1.join(df2, on='tool', how='left').reset_index(drop=False)
    st.subheader("🗂️ Retrieval results")
    styled_results = results_df.style.apply(
        lambda x: [
            "background-color: #F7F7F7" if i % 2 == 0 else "background-color: #FFFFFF"
            for i in range(len(x))
        ],
        axis=0,
    ).format({"relevance": "{:.4f}"})
    st.dataframe(
        styled_results,
        column_config={
            "relevance": st.column_config.ProgressColumn(
                "relevance",
                help="How well the record matches the query",
                format="%.4f",
                min_value=0,
                max_value=float(max(scores)) if len(scores) > 0 else 1,
            ),
            "tool": st.column_config.TextColumn("tool", help="tool help text", width="medium")
        },
        hide_index=True,
        use_container_width=True,
    )
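# To try the app locally (assuming this file is the Space entry point, e.g. app.py,
# and tool-embedding.jsonl is available in the working directory):
#   streamlit run app.py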