|
from collections import Counter
from itertools import count, groupby, islice
from operator import itemgetter
from typing import Any, Iterable, Iterator, TypeVar

import gradio as gr
import pandas as pd
import requests
from datasets import Features
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from requests.adapters import HTTPAdapter, Retry

from analyze import (
    PresidioEntity,
    analyzer,
    get_column_description,
    get_columns_with_strings,
    mask,
    presidio_scan_entities,
)
|
MAX_ROWS = 100

T = TypeVar("T")

# Session with retries so that transient datasets-server errors don't abort a scan
session = requests.Session()
retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
session.mount("https://", HTTPAdapter(max_retries=retries))

DEFAULT_PRESIDIO_ENTITIES = sorted([
    "PERSON",
    "CREDIT_CARD",
    "US_SSN",
    "US_DRIVER_LICENSE",
    "PHONE_NUMBER",
    "US_PASSPORT",
    "EMAIL_ADDRESS",
    "IP_ADDRESS",
    "US_BANK_NUMBER",
    "IBAN_CODE",
    "EMAIL",
])
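# These names are Presidio entity types (plus any custom types registered by the
# analyzer in analyze.py); the CheckboxGroup in the UI below additionally lists
# every entity type that analyzer.get_supported_entities() reports.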
|
|
|
def stream_rows(dataset: str, config: str, split: str) -> Iterable[dict[str, Any]]:
    """Stream the rows of a dataset split from the datasets-server /rows API, batch by batch."""
    batch_size = 100
    for i in count():
        rows_resp = session.get(
            f"https://datasets-server.huggingface.co/rows?dataset={dataset}&config={config}&split={split}&offset={i * batch_size}&length={batch_size}",
            timeout=10,
        ).json()
        if "error" in rows_resp:
            raise RuntimeError(rows_resp["error"])
        if not rows_resp["rows"]:
            break
        for row_item in rows_resp["rows"]:
            yield row_item["row"]
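# A sketch of the /rows payload this function consumes, following the
# datasets-server docs (example values are made up):
#   {
#       "features": [{"feature_idx": 0, "name": "text", "type": {"dtype": "string", "_type": "Value"}}],
#       "rows": [{"row_idx": 0, "row": {"text": "Hello world"}, "truncated_cells": []}],
#       "num_rows_total": 1000
#   }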
|
|
|
class track_iter:
    """Wrap an iterable and count how many items have been yielded so far."""

    def __init__(self, it: Iterable[T]):
        self.it = it
        self.next_idx = 0

    def __iter__(self) -> Iterator[T]:
        for item in self.it:
            self.next_idx += 1
            yield item
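# Minimal usage sketch: next_idx exposes the consumer's progress through the
# wrapped iterable, which analyze_dataset below turns into a progress report.
#   rows = track_iter(["a", "b", "c"])
#   for item in rows:
#       print(rows.next_idx, item)  # 1 a, then 2 b, then 3 c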
|
|
|
|
|
def presidio_report(presidio_entities: list[PresidioEntity], next_row_idx: int, num_rows: int) -> dict[str, float]:
    """Build the gr.Label data: a title entry whose score is the scan progress, plus
    the fraction of scanned rows that contain each entity type."""
    title = f"Scan finished: {len(presidio_entities)} entities found" if num_rows == next_row_idx else "Scan in progress..."
    # Counting the title once per scanned row makes its score next_row_idx / num_rows, i.e. the progress
    counter = Counter([title] * next_row_idx)
    for row_idx, presidio_entities_per_row in groupby(presidio_entities, itemgetter("row_idx")):
        # a set, so that each entity type is counted at most once per row
        counter.update(set("% of rows with " + presidio_entity["type"] for presidio_entity in presidio_entities_per_row))
    return {entity_type: row_count / num_rows for entity_type, row_count in counter.most_common()}
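# Example of the returned mapping (hypothetical numbers): after 50 of 100 rows,
# with PERSON found in 12 rows and EMAIL_ADDRESS in 3, gr.Label receives
#   {"Scan in progress...": 0.5, "% of rows with PERSON": 0.12, "% of rows with EMAIL_ADDRESS": 0.03}
# and renders each entry as a bar, so the title entry doubles as a progress bar.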
|
|
|
|
|
def analyze_dataset(
    dataset: str,
    enabled_presidio_entities: list[str] = DEFAULT_PRESIDIO_ENTITIES,
    show_texts_without_masks: bool = False,
) -> Iterator[tuple[dict[str, float] | str, pd.DataFrame]]:
    """Scan the first MAX_ROWS rows of a dataset and yield (report, entities) updates as the scan progresses."""
    info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
    if "error" in info_resp:
        yield "❌ " + info_resp["error"], pd.DataFrame()
        return
    # Prefer the "default" config and the "train" split, else take the first ones available
    config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
    features = Features.from_dict(info_resp["dataset_info"][config]["features"])
    split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(iter(info_resp["dataset_info"][config]["splits"]))
    num_rows = min(info_resp["dataset_info"][config]["splits"][split]["num_examples"], MAX_ROWS)
    # Only string columns are scanned; their descriptions give the analyzer extra context
    scanned_columns = get_columns_with_strings(features)
    columns_descriptions = [
        get_column_description(column_name, features[column_name]) for column_name in scanned_columns
    ]
    rows = track_iter(islice(stream_rows(dataset, config, split), MAX_ROWS))
    presidio_entities: list[PresidioEntity] = []
    for presidio_entity in presidio_scan_entities(
        rows, scanned_columns=scanned_columns, columns_descriptions=columns_descriptions
    ):
        if not show_texts_without_masks:
            presidio_entity["text"] = mask(presidio_entity["text"])
        if presidio_entity["type"] in enabled_presidio_entities:
            presidio_entities.append(presidio_entity)
        # Stream an intermediate update to the UI after every detected entity
        yield presidio_report(presidio_entities, next_row_idx=rows.next_idx, num_rows=num_rows), pd.DataFrame(presidio_entities)
    # Final update once the scan is complete
    yield presidio_report(presidio_entities, next_row_idx=rows.next_idx, num_rows=num_rows), pd.DataFrame(presidio_entities)
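# A minimal sketch of driving the generator outside of Gradio (assumes network
# access and a public dataset; fake_name_and_ssn is from the examples below):
#   *_, (report, entities_df) = analyze_dataset("lhoestq/fake_name_and_ssn")
#   print(report)       # e.g. {"Scan finished: 4 entities found": 1.0, ...}
#   print(entities_df)  # one row per detected entity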
|
|
|
with gr.Blocks(css=".table {border-collapse: separate}") as demo:
    gr.Markdown("# Scan datasets using Presidio")
    gr.Markdown("The space takes an HF dataset name as input and returns the list of entities detected by Presidio in its first samples.")
    inputs = [
        HuggingfaceHubSearch(
            label="Hub Dataset ID",
            placeholder="Search for a dataset id on Hugging Face",
            search_type="dataset",
        ),
        gr.CheckboxGroup(
            label="Presidio entities",
            choices=sorted(analyzer.get_supported_entities()),
            value=DEFAULT_PRESIDIO_ENTITIES,
            interactive=True,
        ),
        gr.Checkbox(label="Show texts without masks", value=False),
    ]
    button = gr.Button("Run Presidio Scan")
    outputs = [
        gr.Label(show_label=False),
        gr.DataFrame(),
    ]
    button.click(analyze_dataset, inputs, outputs)
    gr.Examples(
        [
            ["microsoft/orca-math-word-problems-200k"],
            ["tatsu-lab/alpaca"],
            ["Anthropic/hh-rlhf"],
            ["OpenAssistant/oasst1"],
            ["sidhq/email-thread-summary"],
            ["lhoestq/fake_name_and_ssn"],
        ],
        inputs,
        outputs,
        fn=analyze_dataset,
        run_on_click=True,
        cache_examples=False,
    )


demo.launch()
|
|