import streamlit as st |
import pandas as pd |
import os |
import base64 |
from pathlib import Path |
import matplotlib.pyplot as plt |
import seaborn as sns |
import numpy as np |
from datasets import load_dataset |
def load_css():
    """Inject the app's custom stylesheet into the Streamlit page.

    Reads ``styles/custom.css`` (resolved against the current working
    directory) and embeds its contents in a ``<style>`` tag rendered through
    ``st.markdown``.

    If the stylesheet does not exist, a warning is shown and the function
    returns without injecting anything, so the page still renders with
    Streamlit's default styling.
    """
    css_path = 'styles/custom.css'
    try:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # can mis-decode or reject non-ASCII characters in the stylesheet.
        with open(css_path, encoding='utf-8') as f:
            css = f.read()
    except FileNotFoundError:
        st.warning(f"Custom stylesheet not found: {css_path}")
        return

    st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)
def create_logo():
    """Render the app logo, or a text heading when the image is absent.

    Looks for ``assets/python_huggingface_logo.png``. When the file exists it
    is opened with PIL and shown 200 pixels wide; otherwise a centred HTML
    heading is rendered in its place.
    """
    from PIL import Image

    logo_path = "assets/python_huggingface_logo.png"

    # Fallback first: no image on disk means a styled text heading instead.
    if not os.path.exists(logo_path):
        st.markdown(
            """
            <div style="display: flex; justify-content: center; margin-bottom: 20px;">
                <h2 style="color: #2196F3;">Python & HuggingFace Explorer</h2>
            </div>
            """,
            unsafe_allow_html=True
        )
        return

    st.image(Image.open(logo_path), width=200)
def get_dataset_info(dataset_name):
    """Load a HuggingFace dataset and summarise it.

    The dataset is first loaded with its default configuration and the first
    split it exposes is used. If that fails, the ``train``, ``test`` and
    ``validation`` splits are tried in that order and the first one that
    loads is used.

    Progress, warnings and errors are reported through Streamlit messages.

    Args:
        dataset_name: Name of the dataset to load; must be a non-empty string.

    Returns:
        A tuple ``(info, data)`` where ``info`` is a dict with the keys
        ``"Dataset"``, ``"Number of examples"``, ``"Features"`` and
        ``"Sample"`` (the first example, or ``None`` for an empty split) and
        ``data`` is the loaded split. ``(None, None)`` is returned when the
        name is invalid or nothing could be loaded.
    """
    if not dataset_name or not isinstance(dataset_name, str):
        st.error("Invalid dataset name")
        return None, None

    try:
        st.info(f"Loading dataset: {dataset_name}...")
        try:
            dataset = load_dataset(dataset_name, streaming=False)
            first_split = next(iter(dataset.keys()))
            data = dataset[first_split]
        except Exception as e:
            st.warning(f"Couldn't load dataset with default configuration: {e}. Trying specific splits...")
            last_error = None
            for split_name in ("train", "test", "validation"):
                try:
                    st.info(f"Trying to load '{split_name}' split...")
                    data = load_dataset(dataset_name, split=split_name, streaming=False)
                    break
                except Exception as split_error:
                    # The name bound by ``except ... as`` is cleared when the
                    # handler exits, so keep a reference for the final report.
                    last_error = split_error
            else:
                # Reached only when no split loaded (the loop never hit
                # ``break``); avoids hard-coding which split name is last.
                st.error(f"Failed to load dataset with any standard split: {last_error}")
                return None, None

        info = {
            "Dataset": dataset_name,
            "Number of examples": len(data),
            "Features": list(data.features.keys()),
            "Sample": data[0] if len(data) > 0 else None,
        }
        st.success(f"Successfully loaded dataset with {info['Number of examples']} examples")
        return info, data
    except Exception as e:
        message = str(e)
        st.error(f"Error loading dataset: {message}")
        # Classify case-insensitively so that e.g. "connection error" and
        # "Connection Error" are both recognised as network problems.
        lowered = message.lower()
        if "connection error" in lowered or "timeout" in lowered:
            st.warning("Network issue detected. Please check your internet connection and try again.")
        elif "not found" in lowered:
            st.warning(f"Dataset '{dataset_name}' not found. Please check the dataset name and try again.")
        return None, None
def run_code(code):
    """Execute a snippet of Python code and capture what it produces.

    The snippet runs via ``exec`` in a namespace pre-populated with ``plt``,
    ``pd``, ``np``, ``sns``, ``print`` and the builtins, plus whichever of
    ``datasets``, ``transformers``, ``sklearn`` and ``math`` can be imported.
    stdout and stderr are captured while it runs.

    Args:
        code: Python source text to execute.

    Returns:
        A dict with three keys:

        * ``"output"``: captured stdout, followed by captured stderr under a
          ``--- Warnings/Errors ---`` heading when stderr is non-empty.
        * ``"error"``: empty on success, otherwise a rejection message or
          ``"<ExceptionType>: <message>"`` for an exception raised while
          running the snippet.
        * ``"figures"``: matplotlib figures opened during the run, in
          figure-number order.
    """
    import importlib
    import io
    import time
    from contextlib import redirect_stdout, redirect_stderr

    results = {
        "output": "",
        "error": "",
        "figures": [],
    }

    if len(code) > 100000:
        results["error"] = "Code submission too large. Please reduce the size."
        return results

    # NOTE(security): this substring blocklist is NOT a sandbox. The snippet
    # still runs with full builtins (``__import__``, ``getattr``, ...), so the
    # checks below are easy to bypass. Do not expose this function to
    # untrusted users without real isolation (separate process/container).
    dangerous_imports = ['os.system', 'subprocess', 'eval(', 'shutil.rmtree', 'open(', 'with open']
    for dangerous_import in dangerous_imports:
        if dangerous_import in code:
            results["error"] = f"Potential security risk: {dangerous_import} is not allowed."
            return results

    # Use the module-level MAX_EXECUTION_TIME when one is defined; otherwise
    # fall back to a default so a missing constant does not turn every
    # successful run into a NameError.
    # NOTE(review): 30 seconds is an assumed default - confirm intended limit.
    max_execution_time = globals().get("MAX_EXECUTION_TIME", 30)

    stdout_capture = io.StringIO()
    stderr_capture = io.StringIO()
    initial_figs = set(plt.get_fignums())
    start_time = time.monotonic()

    try:
        safe_globals = {
            'plt': plt,
            'pd': pd,
            'np': np,
            'sns': sns,
            'print': print,
            '__builtins__': __builtins__,
        }
        # Optional libraries: expose the ones installed, skip the rest.
        for module_name in ['datasets', 'transformers', 'sklearn', 'math']:
            try:
                safe_globals[module_name] = importlib.import_module(module_name)
            except ImportError:
                pass

        with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
            exec(code, safe_globals)

        # The limit is checked only after the snippet has finished; it
        # rejects slow runs but cannot interrupt one that never returns.
        if time.monotonic() - start_time > max_execution_time:
            raise TimeoutError("Code execution exceeded maximum allowed time.")

        results["output"] = stdout_capture.getvalue()
        stderr_output = stderr_capture.getvalue()
        if stderr_output:
            if results["output"]:
                results["output"] += "\n\n--- Warnings/Errors ---\n" + stderr_output
            else:
                results["output"] = "--- Warnings/Errors ---\n" + stderr_output

        # Walk the current figure numbers in order (rather than iterating a
        # set difference) so figures come back in a deterministic order.
        for fig_num in plt.get_fignums():
            if fig_num not in initial_figs:
                results["figures"].append(plt.figure(fig_num))
    except Exception as e:
        results["error"] = f"{type(e).__name__}: {str(e)}"

    return results
def get_dataset_preview(data, max_rows=10):
    """Build a pandas DataFrame from the leading rows of a dataset.

    Args:
        data: Sliceable dataset; ``data[:max_rows]`` is handed to
            ``pd.DataFrame``.
        max_rows: Upper bound of the slice taken from ``data``.

    Returns:
        The preview DataFrame, or ``None`` (after reporting the problem via
        ``st.error``) when slicing or conversion fails.
    """
    try:
        head = data[:max_rows]
        preview = pd.DataFrame(head)
    except Exception as err:
        st.error(f"Error converting dataset to DataFrame: {str(err)}")
        return None
    return preview
def generate_basic_stats(data):
    """Compute per-column summary statistics for a dataset.

    The data is converted to a pandas DataFrame and each column is
    summarised according to its dtype:

    * numeric columns: ``mean``, ``median``, ``std``, ``min``, ``max``;
    * string/object columns: ``unique_values`` and ``most_common`` (the five
      most frequent values as a dict, or the text
      ``"Too many unique values"`` when there are 100 or more distinct
      values);
    * every column, whatever its dtype: ``missing`` (count of NA values).

    Args:
        data: Anything ``pd.DataFrame`` accepts.

    Returns:
        A dict mapping column name to its stats dict, or ``None`` (after
        reporting the problem via ``st.error``) when an error occurs.
    """
    try:
        df = pd.DataFrame(data)
        stats = {}
        for col in df.columns:
            series = df[col]
            col_stats = {}
            if pd.api.types.is_numeric_dtype(series):
                col_stats["mean"] = series.mean()
                col_stats["median"] = series.median()
                col_stats["std"] = series.std()
                col_stats["min"] = series.min()
                col_stats["max"] = series.max()
            elif pd.api.types.is_string_dtype(series) or pd.api.types.is_object_dtype(series):
                # Count distinct values once and reuse the result.
                unique_count = series.nunique()
                col_stats["unique_values"] = unique_count
                if unique_count < 100:
                    col_stats["most_common"] = series.value_counts().head(5).to_dict()
                else:
                    col_stats["most_common"] = "Too many unique values"
            # Missing-value count applies to every dtype, including ones
            # (e.g. datetimes) that match neither branch above.
            col_stats["missing"] = series.isna().sum()
            stats[col] = col_stats
        return stats
    except Exception as e:
        st.error(f"Error generating statistics: {str(e)}")
        return None
def create_visualization(data, viz_type, x_col=None, y_col=None, hue_col=None):
    """Draw a seaborn chart of the requested type on a new matplotlib figure.

    Supported ``viz_type`` values are ``"Bar Chart"``, ``"Line Chart"``,
    ``"Scatter Plot"`` and ``"Box Plot"`` (which need both ``x_col`` and
    ``y_col``) and ``"Histogram"`` and ``"Count Plot"`` (which need only
    ``x_col``). Any other value yields a figure with empty axes.

    Args:
        data: Anything ``pd.DataFrame`` accepts.
        viz_type: Chart type name, as listed above.
        x_col: Column plotted on the X axis.
        y_col: Column plotted on the Y axis.
        hue_col: Optional column used for colour grouping; not applied to
            histograms.

    Returns:
        The matplotlib figure, or ``None`` when a required column is missing
        (a warning is shown) or plotting fails (an error is shown).
    """
    # Charts that need both axes: name -> (seaborn function, warning text).
    xy_charts = {
        "Bar Chart": (sns.barplot, "Bar charts require both X and Y columns."),
        "Line Chart": (sns.lineplot, "Line charts require both X and Y columns."),
        "Scatter Plot": (sns.scatterplot, "Scatter plots require both X and Y columns."),
        "Box Plot": (sns.boxplot, "Box plots require both X and Y columns."),
    }
    # Charts that need only the X column: name -> warning text.
    x_only_warnings = {
        "Histogram": "Histograms require an X column.",
        "Count Plot": "Count plots require an X column.",
    }

    fig = None
    try:
        df = pd.DataFrame(data)

        # Validate before creating the figure so that a rejected request
        # does not leave an orphaned figure in pyplot's registry.
        if viz_type in xy_charts and not (x_col and y_col):
            st.warning(xy_charts[viz_type][1])
            return None
        if viz_type in x_only_warnings and not x_col:
            st.warning(x_only_warnings[viz_type])
            return None

        fig, ax = plt.subplots(figsize=(10, 6))

        if viz_type in xy_charts:
            plot_func = xy_charts[viz_type][0]
            plot_func(x=x_col, y=y_col, hue=hue_col, data=df, ax=ax)
        elif viz_type == "Histogram":
            sns.histplot(df[x_col], ax=ax)
        elif viz_type == "Count Plot":
            sns.countplot(x=x_col, hue=hue_col, data=df, ax=ax)

        if viz_type in x_only_warnings:
            # Single-variable charts: title names only the X column, and the
            # Y axis label set by seaborn is left in place.
            ax.set_title(f"{viz_type} of {x_col}")
            ax.set_xlabel(x_col)
        else:
            ax.set_title(f"{viz_type} of {y_col if y_col else ''} vs {x_col if x_col else ''}")
            ax.set_xlabel(x_col if x_col else "")
            ax.set_ylabel(y_col if y_col else "")

        fig.tight_layout()
        return fig
    except Exception as e:
        # Release the half-built figure; otherwise it stays registered with
        # pyplot and accumulates across failed calls.
        if fig is not None:
            plt.close(fig)
        st.error(f"Error creating visualization: {str(e)}")
        return None
def get_popular_datasets(category=None, limit=10):
    """Return up to ``limit`` names from a built-in list of HuggingFace datasets.

    Args:
        category: One of ``"Text"``, ``"Image"``, ``"Audio"`` or
            ``"Multimodal"``. When falsy or not one of these, names from all
            categories are combined in that order.
        limit: Maximum number of names returned (applied as a slice bound).

    Returns:
        A list of dataset name strings.
    """
    catalog = {
        "Text": ["glue", "imdb", "squad", "wikitext", "ag_news"],
        "Image": ["cifar10", "cifar100", "mnist", "fashion_mnist", "coco"],
        "Audio": ["common_voice", "librispeech_asr", "voxpopuli", "voxceleb", "audiofolder"],
        "Multimodal": ["conceptual_captions", "flickr8k", "hateful_memes", "nlvr", "vqa"],
    }

    if category and category in catalog:
        chosen = catalog[category]
    else:
        # No usable category: flatten every group, keeping declaration order.
        chosen = [name for group in catalog.values() for name in group]

    return chosen[:limit]