File size: 4,941 Bytes
9ab9c77
 
 
 
 
8911626
9ab9c77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b97944f
9ab9c77
7c04c37
 
9ab9c77
 
 
 
 
 
ec1c57d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a302d1e
ec1c57d
 
a302d1e
ec1c57d
 
 
9ab9c77
 
7c04c37
9ab9c77
ec1c57d
9ab9c77
 
 
 
 
 
ec1c57d
 
 
a302d1e
 
 
 
 
 
ec1c57d
a302d1e
ec1c57d
 
 
 
 
 
 
 
 
 
 
7c04c37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ab9c77
a302d1e
 
ec1c57d
 
 
7c04c37
8911626
7c04c37
8911626
7c04c37
ce422d8
a302d1e
7c04c37
 
9ab9c77
ec1c57d
 
 
 
 
a302d1e
8911626
 
 
 
 
 
 
 
ec1c57d
8911626
 
 
 
 
 
 
 
 
a302d1e
 
8911626
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import sys
import faiss
import numpy as np
import streamlit as st
import pandas as pd
from text2vec import SentenceModel
from src.jsonl_Indexer import JSONLIndexer

def get_cli_args():
    """Collect ``key=value`` overrides from the command line.

    Only ``sys.argv[2:]`` is inspected. Each argument containing ``=`` is
    split at the first ``=``; surrounding whitespace is stripped from both
    the key and the value. Arguments without ``=`` are ignored. When a key
    appears more than once, the last occurrence wins.

    Returns:
        A dict mapping option names to their (string) values.
    """
    # NOTE(review): the first two argv entries are skipped — presumably
    # launcher/script tokens; confirm against how the app is started.
    overrides = {}
    for token in sys.argv[2:]:
        name, separator, raw_value = token.partition('=')
        if not separator:
            continue
        overrides[name.strip()] = raw_value.strip()
    return overrides

# Overrides supplied on the command line as key=value pairs.
cli_args = get_cli_args()

# Built-in settings used for any key that is not overridden.
DEFAULT_CONFIG = dict(
    model_path='BAAI/bge-base-en-v1.5',
    dataset_path='tool-embedding.jsonl',
    vector_size=768,
    embedding_field='embedding',
    id_field='id',
)

# Effective configuration: defaults first, command-line values on top.
config = {**DEFAULT_CONFIG, **cli_args}
# Command-line values arrive as strings, so coerce the vector size to int.
config['vector_size'] = int(config['vector_size'])

# ---------------------------
# Cached dataset loader (avoids re-downloading the data on every run)
# ---------------------------
@st.cache_data
def load_tools_datasets():
    """Load and merge the tool collections of ``mangopy/ToolRet-Tools``.

    The ``code``, ``customized`` and ``web`` configurations are loaded in
    that order, their ``tools`` entries are concatenated, and the ``id``
    column is renamed to ``tool``.

    Returns:
        The concatenated dataset with the renamed column.
    """
    from datasets import concatenate_datasets, load_dataset

    subset_names = ("code", "customized", "web")
    tool_splits = [
        load_dataset("mangopy/ToolRet-Tools", subset)['tools']
        for subset in subset_names
    ]
    merged = concatenate_datasets(tool_splits)
    # Rename the 'id' field to 'tool'.
    return merged.rename_columns({'id': 'tool'})

ds = load_tools_datasets()
# With a large dataset, indexing by 'tool' speeds up the later join
# against the retrieval results.
df2 = ds.to_pandas().set_index('tool')

# ---------------------------
# Cached model loader
# ---------------------------
@st.cache_resource
def get_model(model_path: str = config['model_path']):
    """Build the sentence-embedding model for ``model_path``.

    Args:
        model_path: Identifier passed straight to ``SentenceModel``. The
            default is the configured ``model_path``, read once when this
            function is defined.

    Returns:
        The constructed ``SentenceModel`` instance.
    """
    sentence_model = SentenceModel(model_path)
    return sentence_model

# Cached retriever factory
@st.cache_resource
def create_retriever(vector_sz: int, dataset_path: str, embedding_field: str, id_field: str, _model):
    """Create a ``JSONLIndexer`` and load the JSONL dataset into it.

    Args:
        vector_sz: Vector size handed to the indexer.
        dataset_path: Path of the JSONL file to load.
        embedding_field: Name of the field holding each record's embedding.
        id_field: Name of the field holding each record's identifier.
        _model: Embedding model handed to the indexer. Per Streamlit's
            caching convention, the leading underscore excludes this
            argument from cache-key hashing.

    Returns:
        The indexer after ``load_jsonl`` has been called on it.
    """
    indexer = JSONLIndexer(model=_model, vector_sz=vector_sz)
    indexer.load_jsonl(
        dataset_path,
        id_field=id_field,
        embedding_field=embedding_field,
    )
    return indexer

# ---------------------------
# Sidebar configuration
# ---------------------------
st.sidebar.markdown("<div style='text-align: center;'><h3>📄 Model Configuration</h3></div>", unsafe_allow_html=True)
# Offer the configured model first so that a model_path override from the
# command line actually takes effect (the selectbox preselects the first
# option). The built-in default stays available; dict.fromkeys removes the
# duplicate when both are the same while preserving order.
model_options = list(dict.fromkeys([config['model_path'], "BAAI/bge-base-en-v1.5"]))
selected_model = st.sidebar.selectbox("Select Model", model_options)
st.sidebar.write("Selected model:", selected_model)
# Show the configured vector size rather than a hard-coded number, so the
# sidebar stays truthful when vector_size is overridden.
st.sidebar.write(f"Embedding length: {config['vector_size']}")

# Use the model picked in the dropdown (the cached loader avoids reloading it).
model = get_model(selected_model)
retriever = create_retriever(
    config['vector_size'],
    config['dataset_path'],
    config['embedding_field'],
    config['id_field'],
    _model=model
)

# ---------------------------
# Page styling
# ---------------------------
# CSS rules for the search row, the text box and the search button.
# NOTE(review): nothing in this file assigns these class names to an
# element — confirm the selectors actually match the rendered widgets.
_SEARCH_CSS = """
    <style>
    .search-container {
        display: flex;
        justify-content: center;
        align-items: center;
        gap: 10px;
        margin-top: 20px;
    }
    .search-box input {
        width: 500px !important;
        height: 45px;
        font-size: 16px;
        border-radius: 25px;
        padding-left: 15px;
    }
    .search-btn button {
        height: 45px;
        font-size: 16px;
        border-radius: 25px;
    }
    </style>
"""
# Centered page heading.
_PAGE_TITLE_HTML = "<h1 style='text-align: center;'>🔍 Tool Retrieval</h1>"

st.markdown(_SEARCH_CSS, unsafe_allow_html=True)

st.markdown(_PAGE_TITLE_HTML, unsafe_allow_html=True)

# ---------------------------
# Main search area
# ---------------------------
col1, col2 = st.columns([4, 1])
with col1:
    # Streamlit discourages empty widget labels for accessibility reasons;
    # give the input a real label and keep it hidden via label_visibility.
    query = st.text_input(
        "Search query",
        placeholder="Enter your search query...",
        key="search_query",
        label_visibility="collapsed",
    )
with col2:
    search_clicked = st.button("🔎 Search", use_container_width=True)

top_k = st.slider("Top-K tools", 1, 100, 50, help="Choose the number of results to display")

# Treat whitespace-only input as "no query" so it never triggers a retrieval.
cleaned_query = query.strip()

if search_clicked and cleaned_query:
    rec_ids, scores = retriever.search_return_id(cleaned_query, top_k)
    # Build the retrieval-result DataFrame.
    df1 = pd.DataFrame({"relevance": scores, "tool": rec_ids})
    # Join against df2, which is already indexed by 'tool'. The left join
    # keeps every hit; reset_index(drop=False) turns the positional index
    # into a regular 'index' column.
    results_df = df1.join(df2, on='tool', how='left').reset_index(drop=False)

    st.subheader("🗂️ Retrieval results")

    # The styling function is applied per column (axis=0) and returns one
    # style per row, alternating the background colour by row position.
    styled_results = results_df.style.apply(
        lambda column: [
            "background-color: #F7F7F7" if row % 2 == 0 else "background-color: #FFFFFF"
            for row in range(len(column))
        ],
        axis=0,
    ).format({"relevance": "{:.4f}"})

    st.dataframe(
        styled_results,
        column_config={
            "relevance": st.column_config.ProgressColumn(
                "relevance",
                help="记录与查询的匹配程度",
                format="%.4f",
                min_value=0,
                # Scale the bar to the best score; fall back to 1 when
                # there are no scores to take a maximum of.
                max_value=float(max(scores)) if len(scores) > 0 else 1,
            ),
            "tool": st.column_config.TextColumn("tool", help="tool help text", width="medium")
        },
        hide_index=True,
        use_container_width=True,
    )