import streamlit as st
import pandas as pd
import os
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from datasets import load_dataset


def load_css():
    """Load custom CSS, if the stylesheet exists."""
    css_path = Path('styles/custom.css')
    if css_path.exists():
        st.markdown(f'<style>{css_path.read_text()}</style>', unsafe_allow_html=True)
    else:
        st.warning('styles/custom.css not found; using default styling.')


def create_logo():
    """Display the image logo, or a styled text header as a fallback."""
    from PIL import Image

    logo_path = "assets/python_huggingface_logo.png"

    if os.path.exists(logo_path):
        image = Image.open(logo_path)
        st.image(image, width=200)
    else:
        # Fall back to a text header when the logo asset is missing.
        st.markdown(
            """
            <div style="display: flex; justify-content: center; margin-bottom: 20px;">
                <h2 style="color: #2196F3;">Python & HuggingFace Explorer</h2>
            </div>
            """,
            unsafe_allow_html=True
        )


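# Usage sketch for the two UI helpers above. It assumes a Streamlit entry
# point (a hypothetical app.py) that imports this module as `utils`:
#
#   import streamlit as st
#   import utils
#
#   st.set_page_config(page_title="Python & HuggingFace Explorer")
#   utils.load_css()     # inject styles/custom.css, if present
#   utils.create_logo()  # image logo, or the text fallback

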
def get_dataset_info(dataset_name):
    """Get basic information about a HuggingFace dataset."""
    if not dataset_name or not isinstance(dataset_name, str):
        st.error("Invalid dataset name")
        return None, None

    try:
        st.info(f"Loading dataset: {dataset_name}...")

        try:
            # Load every split, then inspect the first one.
            dataset = load_dataset(dataset_name, streaming=False)
            first_split = next(iter(dataset.keys()))
            data = dataset[first_split]
        except Exception as e:
            st.warning(f"Couldn't load dataset with default configuration: {str(e)}. Trying specific splits...")

            # Fall back to loading one standard split at a time.
            for split_name in ["train", "test", "validation"]:
                try:
                    st.info(f"Trying to load '{split_name}' split...")
                    data = load_dataset(dataset_name, split=split_name, streaming=False)
                    break
                except Exception as split_error:
                    if split_name == "validation":
                        # Last split attempted; give up.
                        st.error(f"Failed to load dataset with any standard split: {str(split_error)}")
                        return None, None

        info = {
            "Dataset": dataset_name,
            "Number of examples": len(data),
            "Features": list(data.features.keys()),
            "Sample": data[0] if len(data) > 0 else None
        }

        st.success(f"Successfully loaded dataset with {info['Number of examples']} examples")
        return info, data
    except Exception as e:
        st.error(f"Error loading dataset: {str(e)}")
        if "Connection error" in str(e) or "timeout" in str(e).lower():
            st.warning("Network issue detected. Please check your internet connection and try again.")
        elif "not found" in str(e).lower():
            st.warning(f"Dataset '{dataset_name}' not found. Please check the dataset name and try again.")
        return None, None


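# Usage sketch for get_dataset_info. "imdb" is just an example of a public
# HuggingFace dataset id; any hub dataset works the same way:
#
#   info, data = get_dataset_info("imdb")
#   if info is not None:
#       st.write(info["Features"])            # e.g. ['text', 'label']
#       st.write(info["Number of examples"])

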
def run_code(code):
    """Run Python code and capture its output, errors, and figures."""
    import io
    import time
    from contextlib import redirect_stdout, redirect_stderr

    stdout_capture = io.StringIO()
    stderr_capture = io.StringIO()

    results = {
        "output": "",
        "error": "",
        "figures": []
    }

    # Reject oversized submissions before doing any work.
    if len(code) > 100000:
        results["error"] = "Code submission too large. Please reduce the size."
        return results

    # Naive substring screen for obviously risky constructs. This is not a
    # sandbox: it can be bypassed, and it also rejects benign code that
    # happens to contain one of these patterns.
    blocked_patterns = ['os.system', 'subprocess', 'eval(', 'shutil.rmtree', 'open(', 'with open']
    for pattern in blocked_patterns:
        if pattern in code:
            results["error"] = f"Potential security risk: {pattern} is not allowed."
            return results

    # Remember which figures already exist so only new ones are collected.
    initial_figs = plt.get_fignums()

    MAX_EXECUTION_TIME = 30
    start_time = time.time()

    try:
        # Expose a curated set of globals to the executed code.
        safe_globals = {
            'plt': plt,
            'pd': pd,
            'np': np,
            'sns': sns,
            'print': print,
            '__builtins__': __builtins__,
        }

        # Also expose common libraries, if they are installed.
        for module_name in ['datasets', 'transformers', 'sklearn', 'math']:
            try:
                module = __import__(module_name)
                safe_globals[module_name] = module
            except ImportError:
                pass

        with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
            exec(code, safe_globals)

        # Note: exec() is blocking, so this check only flags an overrun after
        # the fact; it cannot interrupt long-running code.
        if time.time() - start_time > MAX_EXECUTION_TIME:
            raise TimeoutError("Code execution exceeded maximum allowed time.")

        results["output"] = stdout_capture.getvalue()

        # Append anything written to stderr (e.g. warnings) to the output.
        stderr_output = stderr_capture.getvalue()
        if stderr_output:
            if results["output"]:
                results["output"] += "\n\n--- Warnings/Errors ---\n" + stderr_output
            else:
                results["output"] = "--- Warnings/Errors ---\n" + stderr_output

        # Collect only the figures created during this run.
        final_figs = plt.get_fignums()
        new_figs = set(final_figs) - set(initial_figs)
        for fig_num in new_figs:
            fig = plt.figure(fig_num)
            results["figures"].append(fig)

    except Exception as e:
        results["error"] = f"{type(e).__name__}: {str(e)}"

    return results


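# Usage sketch for run_code. The snippet below is arbitrary user code; note
# that figures are returned rather than shown, so the caller renders them:
#
#   results = run_code("import numpy as np\nprint(np.arange(3))")
#   if results["error"]:
#       st.error(results["error"])
#   else:
#       st.code(results["output"])
#       for fig in results["figures"]:
#           st.pyplot(fig)

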
def get_dataset_preview(data, max_rows=10):
    """Convert a HuggingFace dataset to a pandas DataFrame for preview"""
    try:
        # Slicing a Dataset returns a dict of columns, which pandas accepts.
        df = pd.DataFrame(data[:max_rows])
        return df
    except Exception as e:
        st.error(f"Error converting dataset to DataFrame: {str(e)}")
        return None


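# Usage sketch: preview the first rows of a loaded dataset in the UI.
#
#   df = get_dataset_preview(data, max_rows=5)
#   if df is not None:
#       st.dataframe(df)

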
def generate_basic_stats(data):
    """Generate basic per-column statistics for a dataset"""
    try:
        df = pd.DataFrame(data)

        stats = {}

        for col in df.columns:
            col_stats = {}

            # Numeric columns: summary statistics.
            if pd.api.types.is_numeric_dtype(df[col]):
                col_stats["mean"] = df[col].mean()
                col_stats["median"] = df[col].median()
                col_stats["std"] = df[col].std()
                col_stats["min"] = df[col].min()
                col_stats["max"] = df[col].max()
                col_stats["missing"] = df[col].isna().sum()

            # String/object columns: cardinality and most common values.
            elif pd.api.types.is_string_dtype(df[col]) or pd.api.types.is_object_dtype(df[col]):
                col_stats["unique_values"] = df[col].nunique()
                col_stats["most_common"] = (
                    df[col].value_counts().head(5).to_dict()
                    if df[col].nunique() < 100
                    else "Too many unique values"
                )
                col_stats["missing"] = df[col].isna().sum()

            stats[col] = col_stats

        return stats
    except Exception as e:
        st.error(f"Error generating statistics: {str(e)}")
        return None


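# Usage sketch: the returned dict maps column names to per-column stats, so
# it can be rendered directly with st.json or reshaped into a table:
#
#   stats = generate_basic_stats(data)
#   if stats is not None:
#       st.json(stats)

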
def create_visualization(data, viz_type, x_col=None, y_col=None, hue_col=None):
    """Create a visualization based on the selected type and columns"""
    try:
        df = pd.DataFrame(data)

        fig, ax = plt.subplots(figsize=(10, 6))

        if viz_type == "Bar Chart":
            if x_col and y_col:
                sns.barplot(x=x_col, y=y_col, hue=hue_col, data=df, ax=ax)
            else:
                st.warning("Bar charts require both X and Y columns.")
                return None

        elif viz_type == "Line Chart":
            if x_col and y_col:
                sns.lineplot(x=x_col, y=y_col, hue=hue_col, data=df, ax=ax)
            else:
                st.warning("Line charts require both X and Y columns.")
                return None

        elif viz_type == "Scatter Plot":
            if x_col and y_col:
                sns.scatterplot(x=x_col, y=y_col, hue=hue_col, data=df, ax=ax)
            else:
                st.warning("Scatter plots require both X and Y columns.")
                return None

        elif viz_type == "Histogram":
            if x_col:
                sns.histplot(df[x_col], ax=ax)
            else:
                st.warning("Histograms require an X column.")
                return None

        elif viz_type == "Box Plot":
            if x_col and y_col:
                sns.boxplot(x=x_col, y=y_col, hue=hue_col, data=df, ax=ax)
            else:
                st.warning("Box plots require both X and Y columns.")
                return None

        elif viz_type == "Count Plot":
            if x_col:
                sns.countplot(x=x_col, hue=hue_col, data=df, ax=ax)
            else:
                st.warning("Count plots require an X column.")
                return None

        else:
            # Unknown plot type: warn instead of returning an empty figure.
            st.warning(f"Unsupported visualization type: {viz_type}")
            return None

        # Build a title that reads correctly for one- and two-column plots.
        if x_col and y_col:
            plt.title(f"{viz_type} of {y_col} vs {x_col}")
        else:
            plt.title(f"{viz_type} of {x_col or y_col}")
        plt.xlabel(x_col if x_col else "")
        plt.ylabel(y_col if y_col else "")
        plt.tight_layout()

        return fig

    except Exception as e:
        st.error(f"Error creating visualization: {str(e)}")
        return None


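# Usage sketch: the column names here ("label", "length") are placeholders
# for whatever columns the loaded dataset actually has:
#
#   fig = create_visualization(data, "Count Plot", x_col="label")
#   if fig is not None:
#       st.pyplot(fig)

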
def get_popular_datasets(category=None, limit=10):
    """Get popular HuggingFace datasets, optionally filtered by category"""
    popular_datasets = {
        "Text": ["glue", "imdb", "squad", "wikitext", "ag_news"],
        "Image": ["cifar10", "cifar100", "mnist", "fashion_mnist", "coco"],
        "Audio": ["common_voice", "librispeech_asr", "voxpopuli", "voxceleb", "audiofolder"],
        "Multimodal": ["conceptual_captions", "flickr8k", "hateful_memes", "nlvr", "vqa"]
    }

    if category and category in popular_datasets:
        return popular_datasets[category][:limit]
    else:
        # No (or unknown) category: return datasets from all categories.
        all_datasets = []
        for cat_datasets in popular_datasets.values():
            all_datasets.extend(cat_datasets)
        return all_datasets[:limit]


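# Usage sketch: populate a selectbox with the curated text datasets.
#
#   options = get_popular_datasets(category="Text", limit=5)
#   choice = st.selectbox("Pick a dataset", options)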