asoria HF staff committed on
Commit
e6bb5bf
·
1 Parent(s): a52826f

Add application file

Browse files
Files changed (3) hide show
  1. README.md +4 -3
  2. app.py +166 -0
  3. requirements.txt +9 -0
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: Datasets Similarity Tool
3
- emoji: 📉
4
- colorFrom: pink
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 4.19.2
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Datasets Similarity Tool
3
+ emoji: 🐨
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.19.2
8
  app_file: app.py
9
  pinned: false
10
+ startup_duration_timeout: 1h
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+
4
+ import pandas as pd
5
+ from httpx import Client
6
+ from huggingface_hub.utils import logging
7
+ from functools import lru_cache
8
+ from tqdm.contrib.concurrent import thread_map
9
+ from huggingface_hub import HfApi
10
+ import gradio as gr
11
+ from sentence_transformers import SentenceTransformer
12
+ import faiss
13
+ import numpy as np
14
+ from urllib.parse import quote
15
+
16
# Read environment variables via python-dotenv before looking up the token.
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
assert HF_TOKEN is not None, "You need to set HF_TOKEN in your environment variables"

BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"

logger = logging.get_logger(__name__)
# The header value must be exactly "Bearer <token>". The previous
# f"Bearer ${HF_TOKEN}" kept a literal "$" (JavaScript template syntax) in
# front of the token, so every request carried a malformed credential.
headers = {
    "authorization": f"Bearer {HF_TOKEN}",
}
# Shared HTTP client used by all datasets-server requests in this module.
client = Client(headers=headers)
api = HfApi(token=HF_TOKEN)
29
+
30
+
31
def get_first_config_name(dataset: str):
    """Return the config name of the first split listed for ``dataset``.

    Queries the datasets-server ``/splits`` endpoint through the shared
    ``client``.

    Args:
        dataset: Hub id of the dataset.

    Returns:
        The ``config`` value of the first entry in ``splits``, or ``None``
        when the request or the response parsing fails (the error is logged).
    """
    try:
        resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={dataset}")
        data = resp.json()
        # "config" already holds the whole config name; indexing it again
        # with [0] returned only its first character.
        return data["splits"][0]["config"]
    except Exception as e:
        logger.error(f"Failed to get splits for {dataset}: {e}")
        return None
39
+
40
+
41
def datasets_server_valid_rows(dataset: str):
    """Return the ``viewer`` field of the ``/is-valid`` response for ``dataset``.

    Any failure (request, JSON decoding, missing key) is logged and results
    in ``None``.
    """
    try:
        payload = client.get(
            f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={dataset}"
        ).json()
        return payload["viewer"]
    except Exception as e:
        logger.error(f"Failed to get is-valid for {dataset}: {e}")
        return None
48
+
49
+
50
def dataset_is_valid(dataset):
    """Return ``dataset`` when ``datasets_server_valid_rows`` is truthy for its id.

    Returns ``None`` otherwise, so callers can filter on ``is not None``.
    """
    if datasets_server_valid_rows(dataset.id):
        return dataset
    return None
52
+
53
+
54
def get_first_config_and_split_name(hub_id: str):
    """Return ``(config, split)`` of the first split listed for ``hub_id``.

    Queries the datasets-server ``/splits`` endpoint through the shared
    ``client``.

    Args:
        hub_id: Hub id of the dataset.

    Returns:
        A ``(config, split)`` tuple taken from the first entry in ``splits``,
        or ``None`` when the request or the response parsing fails (the error
        is logged).
    """
    try:
        # Use the shared base URL constant, like every other request in this
        # module, instead of repeating the host name here.
        resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}")
        data = resp.json()
        first_split = data["splits"][0]
        return first_split["config"], first_split["split"]
    except Exception as e:
        logger.error(f"Failed to get splits for {hub_id}: {e}")
        return None
65
+
66
+
67
def get_dataset_info(hub_id: str, config: str | None = None):
    """Fetch the datasets-server ``/info`` payload for ``hub_id``.

    Args:
        hub_id: Hub id of the dataset.
        config: Config name to query. When ``None``, the config of the first
            split reported by ``get_first_config_and_split_name`` is used.

    Returns:
        The decoded JSON response, or ``None`` when no config was given and
        none could be resolved.

    Raises:
        Whatever ``resp.raise_for_status()`` raises on an error status.
    """
    if config is None:
        first = get_first_config_and_split_name(hub_id)
        if first is None:
            return None
        # The helper returns (config, split); only the config is needed.
        config = first[0]
    response = client.get(
        f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    )
    response.raise_for_status()
    return response.json()
79
+
80
+
81
def dataset_with_info(dataset):
    """Build the record describing ``dataset`` that is later embedded.

    The record holds the dataset id, its comma-joined column names, the
    ``text`` string used for the embedding (``"<id>-<columns>"``) and the
    ``likes``, ``downloads``, ``created_at`` and ``tags`` attributes of
    ``dataset``.

    Returns:
        The record dict, or ``None`` when no info is available, the features
        entry is ``None``, or any exception occurs (the error is logged).
    """
    try:
        info = get_dataset_info(dataset.id)
        if info:
            features = info.get("dataset_info", {}).get("features", {})
            if features is not None:
                # Only the feature names (dict keys) are used, not their types.
                joined_columns = ",".join(list(features.keys()))
                return {
                    "dataset": dataset.id,
                    "column_names": joined_columns,
                    "text": f"{dataset.id}-{joined_columns}",
                    "likes": dataset.likes,
                    "downloads": dataset.downloads,
                    "created_at": dataset.created_at,
                    "tags": dataset.tags,
                }
    except Exception as e:
        logger.error(f"Failed to get info for {dataset.id}: {e}")
    return None
98
+
99
+
100
+
101
@lru_cache(maxsize=100)
def prep_data():
    """Collect the info records for the Hub datasets handled by this app.

    Steps: list the datasets via ``api.list_datasets``, keep those for which
    ``dataset_is_valid`` returns a value, then build a record for each with
    ``dataset_with_info`` and drop the ``None`` results. Both passes run
    through ``thread_map``. The result is memoized by ``lru_cache``.

    Returns:
        The list of record dicts.

    Raises:
        IndexError: if no record could be built, because the first record is
            printed unconditionally.
    """
    hub_datasets = list(
        api.list_datasets(limit=None, sort="createdAt", direction=-1)
    )
    print(f"Found {len(hub_datasets)} datasets in the hub.")
    logger.info(f"Found {len(hub_datasets)} datasets.")

    checked = thread_map(dataset_is_valid, hub_datasets)
    datasets_with_server = [item for item in checked if item is not None]
    print(f"Found {len(datasets_with_server)} datasets with server.")

    records = thread_map(dataset_with_info, datasets_with_server)
    records = [record for record in records if record is not None]
    print(f"Found {len(records)} datasets with server data.")
    print(records[0])
    return records
117
+
118
# Module-level setup: gather the dataset records, embed their "text" field
# and load the embeddings into a FAISS index that `search` queries.
all_datasets = prep_data()
all_datasets_df = pd.DataFrame.from_dict(all_datasets)
print(all_datasets_df.head())

# Embed the "<dataset id>-<column names>" strings.
text = all_datasets_df["text"]
encoder = SentenceTransformer("paraphrase-mpnet-base-v2")
vectors = encoder.encode(text)
vector_dimension = vectors.shape[1]

print("Start indexing")
index = faiss.IndexFlatL2(vector_dimension)
# Vectors are L2-normalized in place before being added; `search`
# normalizes its query vector the same way.
faiss.normalize_L2(vectors)
index.add(vectors)
print("Indexing done")
130
+
131
def render_model_hub_link(hub_id):
    """Return an HTML anchor pointing at the dataset page of ``hub_id``.

    The id is percent-encoded with ``quote`` for the ``href`` and shown
    verbatim as the link text; the link opens in a new tab.
    """
    href = "https://huggingface.co/datasets/" + quote(hub_id)
    css = (
        "color: var(--link-text-color); "
        "text-decoration: underline;text-decoration-style: dotted;"
    )
    return f'<a target="_blank" href="{href}" style="{css}">{hub_id}</a>'
134
+
135
+
136
def search(dataset_name):
    """Return the indexed datasets closest to ``dataset_name``.

    Looks the dataset up in ``all_datasets_df``, embeds its ``text`` value,
    queries the FAISS ``index`` for the 20 nearest entries and joins the hits
    back onto the dataset table.

    Args:
        dataset_name: Hub id typed by the user. Surrounding whitespace is
            ignored; ``None`` is treated as an empty string.

    Returns:
        A DataFrame with ``distances``, ``ann`` (row position in the index)
        and the dataset columns, where ``dataset`` is rendered as an HTML
        link. If the dataset is not in the table, a one-row DataFrame with an
        ``error`` column is returned instead.
    """
    # A pasted id often carries stray whitespace, which would otherwise make
    # the exact-match lookup below fail with a misleading "not found".
    dataset_name = (dataset_name or "").strip()
    print(f"start search for {dataset_name}")
    try:
        dataset_row = all_datasets_df[all_datasets_df.dataset == dataset_name].iloc[0]
        print(dataset_row)
    except IndexError:
        # .iloc[0] on an empty selection: the dataset is not indexed.
        return pd.DataFrame([{"error": "❌ Dataset does not exist or is not supported"}])

    text = dataset_row["text"]
    search_vector = encoder.encode(text)
    # FAISS expects a 2-D batch; normalize like the indexed vectors.
    _vector = np.array([search_vector])
    faiss.normalize_L2(_vector)
    distances, ann = index.search(_vector, k=20)

    results = pd.DataFrame({"distances": distances[0], "ann": ann[0]})
    print("results for distances and ann")
    print(results)

    # `ann` holds positions in the index, which match the table's row index.
    merge = pd.merge(results, all_datasets_df, left_on="ann", right_index=True)
    print("results for merged df (distances, ann, dataset info)")
    print(merge)
    merge["dataset"] = merge["dataset"].apply(render_model_hub_link)
    return merge
155
+
156
# Gradio UI: a textbox for the dataset id, a button, and a table of results.
with gr.Blocks() as demo:
    gr.Markdown("# Search similar Datasets on Hugging Face")
    # The embedded text is "<dataset id>-<column names>"; column types are
    # not part of it, so the description must not claim they are.
    gr.Markdown(
        "This space shows similar datasets based on dataset id and column names"
    )
    dataset_name = gr.Textbox(
        "asoria/bolivian-population", label="Dataset Name"
    )
    btn = gr.Button("Show similar datasets")
    # "markdown" datatype so the HTML links produced by `search` are rendered.
    df = gr.DataFrame(datatype="markdown")
    btn.click(search, dataset_name, df)

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ Pyarrow
2
+ gradio==4.18.0
3
+ httpx
4
+ huggingface_hub
5
+ pandas
6
+ python-dotenv
7
+ datasets
8
+ sentence-transformers
9
+ faiss-cpu