|
from functools import lru_cache |
|
|
|
import duckdb |
|
import gradio as gr |
|
import polars as pl |
|
from datasets import load_dataset |
|
from gradio_huggingfacehub_search import HuggingfaceHubSearch |
|
from model2vec import StaticModel |
|
|
|
# Static (model2vec) embedding model shared by the whole app: it embeds the
# selected dataset column (vectorize_dataset) and each search query
# (run_query). Loaded once, at import time.
model_name = "minishlab/potion-base-8M"
model = StaticModel.from_pretrained(model_name)
|
|
|
|
|
def get_iframe(hub_repo_id):
    """Build the HTML for an embedded Hugging Face dataset viewer.

    Args:
        hub_repo_id: Dataset repository id on the Hub, e.g. ``"owner/name"``.

    Returns:
        An HTML snippet with an ``<iframe>`` pointing at the dataset's
        ``/embed/viewer`` page.

    Raises:
        ValueError: If ``hub_repo_id`` is empty or ``None``.
    """
    import html
    from urllib.parse import quote

    if not hub_repo_id:
        raise ValueError("Hub repo id is required")

    # The repo id may be arbitrary text typed into the search box, so
    # percent-encode it (keeping "/" as the owner/name separator) and
    # HTML-escape the URL before placing it in the src attribute; otherwise
    # a '"' in the input could break out of the attribute.
    url = f"https://huggingface.co/datasets/{quote(hub_repo_id, safe='/')}/embed/viewer"
    safe_url = html.escape(url, quote=True)
    iframe = f"""
    <iframe
        src="{safe_url}"
        frameborder="0"
        width="100%"
        height="600px"
    ></iframe>
    """
    return iframe
|
|
|
|
|
def load_dataset_from_hub(hub_repo_id: str):
    """Show a loading toast and load ``hub_repo_id`` with ``load_dataset``.

    The loaded dataset is not returned (the event wired to this function has
    no outputs); the call is made for its side effect of fetching the data.

    Args:
        hub_repo_id: Dataset repository id on the Hugging Face Hub.
    """
    gr.Info(message="Loading dataset...")
    # Result intentionally discarded. Presumably this front-loads the download
    # so the later load_dataset calls hit the local `datasets` cache.
    # TODO confirm that is the intent.
    load_dataset(hub_repo_id)
|
|
|
|
|
def get_columns(hub_repo_id: str, split: str):
    """Return a dropdown listing the columns of one split of a Hub dataset.

    Args:
        hub_repo_id: Dataset repository id on the Hugging Face Hub.
        split: Name of the split whose columns should be listed.

    Returns:
        A visible ``gr.Dropdown`` whose choices are the split's column names
        and whose value is the first column (``None`` if the split has no
        columns). If ``split`` is empty, a hidden dropdown update is returned
        instead.
    """
    # If the split dropdown is empty (e.g. cleared), ds[split] would raise
    # KeyError; hide the column picker instead of erroring.
    if not split:
        return gr.Dropdown(visible=False)

    ds = load_dataset(hub_repo_id)
    columns = ds[split].column_names
    return gr.Dropdown(
        choices=columns,
        # Guard against a split without columns, where columns[0] would
        # raise IndexError.
        value=columns[0] if columns else None,
        label="Select a column",
        visible=True,
    )
|
|
|
|
|
def get_splits(hub_repo_id: str):
    """Return a visible dropdown of the split names in a Hub dataset.

    The first split reported by ``load_dataset`` is preselected.
    """
    split_names = [*load_dataset(hub_repo_id).keys()]
    default_split = split_names[0]
    return gr.Dropdown(
        label="Select a split",
        choices=split_names,
        value=default_split,
        visible=True,
    )
|
|
|
|
|
# Each cache entry holds an embedding matrix for an entire split, so keep the
# cache small rather than relying on lru_cache's default of 128 entries.
@lru_cache(maxsize=8)
def vectorize_dataset(hub_repo_id: str, split: str, column: str):
    """Embed every value of one column of a Hub dataset split.

    Results are memoized per ``(hub_repo_id, split, column)``, so the
    "Vectorizing dataset..." toast and the encoding only happen on a
    cache miss.

    Args:
        hub_repo_id: Dataset repository id on the Hugging Face Hub.
        split: Split to read.
        column: Column whose values are embedded. Values are cast to
            strings, and missing values are embedded as ``""``.

    Returns:
        The output of ``model.encode`` for the column values, one embedding
        per row in row order (presumably a 2-D numpy array; verify against
        model2vec).
    """
    gr.Info("Vectorizing dataset...")
    ds = load_dataset(hub_repo_id)
    texts = (
        # Only the target column is needed; avoid converting the whole split.
        ds[split]
        .select_columns(column)
        .to_polars()[column]
        .cast(str)
        # Nulls survive the cast and would reach the tokenizer as None.
        .fill_null("")
        .to_list()
    )
    return model.encode(texts, max_length=512)
|
|
|
|
|
def run_query(
    hub_repo_id: str, query: str, split: str, column: str, limit: int = 5
):
    """Return the dataset rows whose embeddings are closest to a text query.

    Args:
        hub_repo_id: Dataset repository id on the Hugging Face Hub.
        query: Free-text query to embed and compare against ``column``.
        split: Split to search.
        column: Column whose embeddings (from ``vectorize_dataset``) are
            searched.
        limit: Maximum number of rows to return (default 5).

    Returns:
        A visible ``gr.Dataframe`` with up to ``limit`` rows of the split,
        ordered by ascending cosine distance to the query. Only the
        dataset's own columns are included.

    Raises:
        gr.Error: If the query is empty, or if embedding the query or
            running the DuckDB search fails.
    """
    if not (query and query.strip()):
        raise gr.Error("Query is required")

    row_limit = int(limit)
    embeddings = vectorize_dataset(hub_repo_id, split, column)
    ds = load_dataset(hub_repo_id)
    df = ds[split].to_polars()
    # Private column name so a dataset that already has an "embeddings"
    # column is not overwritten; it is excluded from the results below.
    embedding_col = "_query_embeddings"
    df = df.with_columns(pl.Series(embeddings).alias(embedding_col))

    try:
        query_values = model.encode(query).tolist()
        # Derive the fixed array width from the query embedding instead of
        # hard-coding it, so the cast matches whichever model is loaded.
        dim = len(query_values)
        # `FROM df` refers to the local polars DataFrame above (DuckDB
        # replacement scan).
        df_results = duckdb.sql(
            query=f"""
            SELECT * EXCLUDE ({embedding_col})
            FROM df
            ORDER BY array_cosine_distance({embedding_col}, {query_values}::FLOAT[{dim}])
            LIMIT {row_limit}
            """
        ).to_df()
    except Exception as e:
        raise gr.Error(f"Error running query: {e}") from e

    return gr.Dataframe(df_results, visible=True)
|
|
|
|
|
def hide_components():
    """Return hidden-component updates for every control below the viewer.

    Order: split dropdown, column dropdown, query textbox, search button,
    results dataframe.
    """
    hidden_kinds = (gr.Dropdown, gr.Dropdown, gr.Textbox, gr.Button, gr.Dataframe)
    return [kind(visible=False) for kind in hidden_kinds]
|
|
|
|
|
def partial_hide_components():
    """Hide the query textbox, the search button and the results table."""
    textbox_update = gr.Textbox(visible=False)
    button_update = gr.Button(visible=False)
    table_update = gr.Dataframe(visible=False)
    return [textbox_update, button_update, table_update]
|
|
|
|
|
def show_components():
    """Reveal the labelled query textbox and the "Search" button."""
    query_box = gr.Textbox(label="Query", visible=True)
    search_button = gr.Button(value="Search", visible=True)
    return [query_box, search_button]
|
|
|
|
|
# Build the UI, wire the event chain, and start the server.
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
        This app allows you to vector search any Hugging Face dataset.
        Pick a dataset, a split and a column, then enter a text query to find
        the rows whose embeddings are most similar to it.
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                # Must describe datasets: search_type below is "dataset".
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # NOTE(review): spelling kept from the original call; confirm it
                # matches the keyword HuggingfaceHubSearch actually accepts.
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")

    with gr.Row():
        split_dropdown = gr.Dropdown(label="Select a split", visible=False)
        column_dropdown = gr.Dropdown(label="Select a column", visible=False)
    with gr.Row():
        query_input = gr.Textbox(label="Query", visible=False)

    btn_run = gr.Button("Search", visible=False)

    results_output = gr.Dataframe(label="Results", visible=False)

    # On dataset selection: show the viewer, load the data, reset the
    # downstream controls, then repopulate splits and columns.
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
        fn=load_dataset_from_hub,
        inputs=search_in,
        show_progress=True,
    ).then(
        fn=hide_components,
        outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
    ).then(fn=get_splits, inputs=search_in, outputs=split_dropdown).then(
        # Also refresh columns explicitly: if the new dataset's first split has
        # the same name as the previous one, split_dropdown.change may not fire.
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )

    split_dropdown.change(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )

    column_dropdown.change(
        fn=partial_hide_components,
        outputs=[query_input, btn_run, results_output],
    ).then(fn=show_components, outputs=[query_input, btn_run])

    btn_run.click(
        fn=run_query,
        inputs=[search_in, query_input, split_dropdown, column_dropdown],
        outputs=results_output,
    )

demo.launch()
|
|