File size: 11,013 Bytes
2a64443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import streamlit as st
import pandas as pd
import os
import base64
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from datasets import load_dataset

def load_css():
    """Inject the project's custom stylesheet into the Streamlit page.

    Reads ``styles/custom.css`` relative to the working directory and
    embeds it in a ``<style>`` tag via ``st.markdown``.

    Raises:
        FileNotFoundError: if ``styles/custom.css`` does not exist.
    """
    # Explicit encoding: the default is locale-dependent and can mangle
    # non-ASCII characters in the stylesheet on some platforms.
    with open('styles/custom.css', encoding='utf-8') as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)

def create_logo():
    """Show the app logo image, or a styled text header when the image is absent."""
    from PIL import Image
    import os

    logo_path = "assets/python_huggingface_logo.png"

    if not os.path.exists(logo_path):
        # No logo file on disk — render the text fallback instead.
        st.markdown(
            """
            <div style="display: flex; justify-content: center; margin-bottom: 20px;">
                <h2 style="color: #2196F3;">Python & HuggingFace Explorer</h2>
            </div>
            """,
            unsafe_allow_html=True
        )
        return

    # Logo found: load and render it at a fixed width.
    st.image(Image.open(logo_path), width=200)

def get_dataset_info(dataset_name):
    """Load a HuggingFace dataset and summarize it.

    Args:
        dataset_name: Hub identifier of the dataset (non-empty string).

    Returns:
        Tuple ``(info, data)`` where ``info`` is a dict holding the dataset
        name, example count, feature names and the first example, and
        ``data`` is the loaded split.  Returns ``(None, None)`` on any
        failure; progress and errors are reported through ``st.*`` widgets.
    """
    # Guard against None / non-string input before touching the network.
    if not dataset_name or not isinstance(dataset_name, str):
        st.error("Invalid dataset name")
        return None, None
        
    try:
        # Attempt to load the dataset with default configuration
        st.info(f"Loading dataset: {dataset_name}...")
        
        try:
            # First try to load the dataset with streaming=False for better compatibility
            dataset = load_dataset(dataset_name, streaming=False)
            # Get the first split (DatasetDict preserves insertion order,
            # so this is typically "train").
            first_split = next(iter(dataset.keys()))
            data = dataset[first_split]
        except Exception as e:
            st.warning(f"Couldn't load dataset with default configuration: {str(e)}. Trying specific splits...")
            # If that fails, try loading with specific splits
            for split_name in ["train", "test", "validation"]:
                try:
                    st.info(f"Trying to load '{split_name}' split...")
                    data = load_dataset(dataset_name, split=split_name, streaming=False)
                    break
                except Exception as split_error:
                    if split_name == "validation":  # Last attempt
                        st.error(f"Failed to load dataset with any standard split: {str(split_error)}")
                        return None, None
                    continue
        
        # Get basic info
        info = {
            "Dataset": dataset_name,
            "Number of examples": len(data),
            "Features": list(data.features.keys()),
            "Sample": data[0] if len(data) > 0 else None
        }
        
        st.success(f"Successfully loaded dataset with {info['Number of examples']} examples")
        return info, data
    except Exception as e:
        # Classify common failure modes by message text to give the user a hint.
        st.error(f"Error loading dataset: {str(e)}")
        if "Connection error" in str(e) or "timeout" in str(e).lower():
            st.warning("Network issue detected. Please check your internet connection and try again.")
        elif "not found" in str(e).lower():
            st.warning(f"Dataset '{dataset_name}' not found. Please check the dataset name and try again.")
        return None, None

def run_code(code):
    """Execute user-supplied Python *code* and capture its results.

    Returns a dict with keys:
        "output":  captured stdout (with a "--- Warnings/Errors ---" section
                   appended when anything was written to stderr),
        "error":   error description, "" on success,
        "figures": matplotlib figures newly created by the code.

    NOTE(review): the substring blocklist below is not a sandbox —
    ``exec`` of untrusted input remains fundamentally unsafe.
    """
    import io
    import sys
    import time
    from contextlib import redirect_stdout, redirect_stderr
    
    # Create StringIO objects to capture stdout and stderr
    stdout_capture = io.StringIO()
    stderr_capture = io.StringIO()
    
    # Dictionary for storing results
    results = {
        "output": "",
        "error": "",
        "figures": []
    }
    
    # Safety check - limit code size
    if len(code) > 100000:
        results["error"] = "Code submission too large. Please reduce the size."
        return results
        
    # Basic security check - this is not comprehensive
    # (simple substring match; easy to bypass, see NOTE in docstring)
    dangerous_imports = ['os.system', 'subprocess', 'eval(', 'shutil.rmtree', 'open(', 'with open']
    for dangerous_import in dangerous_imports:
        if dangerous_import in code:
            results["error"] = f"Potential security risk: {dangerous_import} is not allowed."
            return results
    
    # Capture current figures to avoid including existing ones
    initial_figs = plt.get_fignums()
    
    # Set execution timeout
    MAX_EXECUTION_TIME = 30  # seconds
    start_time = time.time()
    
    try:
        # Create a restricted globals dictionary
        # (full __builtins__ is still exposed, so this is convenience,
        # not isolation)
        safe_globals = {
            'plt': plt,
            'pd': pd,
            'np': np,
            'sns': sns,
            'print': print,
            '__builtins__': __builtins__,
        }
        
        # Add common data science libraries
        for module_name in ['datasets', 'transformers', 'sklearn', 'math']:
            try:
                module = __import__(module_name)
                safe_globals[module_name] = module
            except ImportError:
                pass  # Module not available
        
        # Redirect stdout and stderr
        with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
            # Execute the code with timeout check
            exec(code, safe_globals)
            
            # NOTE(review): this check only runs after exec() has already
            # returned, so it cannot interrupt long-running code — it merely
            # reports overruns after the fact. Confirm whether a real
            # pre-emptive timeout was intended.
            if time.time() - start_time > MAX_EXECUTION_TIME:
                raise TimeoutError("Code execution exceeded maximum allowed time.")
        
        # Get the captured output
        results["output"] = stdout_capture.getvalue()
        
        # Also capture stderr
        stderr_output = stderr_capture.getvalue()
        if stderr_output:
            if results["output"]:
                results["output"] += "\n\n--- Warnings/Errors ---\n" + stderr_output
            else:
                results["output"] = "--- Warnings/Errors ---\n" + stderr_output
        
        # Capture any figures that were created
        # (diff of figure numbers before/after the exec)
        final_figs = plt.get_fignums()
        new_figs = set(final_figs) - set(initial_figs)
        
        for fig_num in new_figs:
            fig = plt.figure(fig_num)
            results["figures"].append(fig)
    
    except Exception as e:
        # Capture the error
        results["error"] = f"{type(e).__name__}: {str(e)}"
    
    return results

def get_dataset_preview(data, max_rows=10):
    """Return the first *max_rows* rows of *data* as a pandas DataFrame.

    On conversion failure an error is shown via ``st.error`` and ``None``
    is returned.
    """
    try:
        head = data[:max_rows]
        return pd.DataFrame(head)
    except Exception as e:
        st.error(f"Error converting dataset to DataFrame: {str(e)}")
        return None

def generate_basic_stats(data):
    """Generate basic per-column statistics for a dataset.

    Args:
        data: anything ``pd.DataFrame`` accepts (list of dicts,
              dict of columns, HuggingFace dataset slice, ...).

    Returns:
        Dict mapping column name -> stats dict.  Numeric columns get
        mean/median/std/min/max/missing; string/object columns get
        unique_values/most_common/missing; other dtypes get an empty
        dict.  Returns ``None`` (after ``st.error``) on failure.
    """
    try:
        df = pd.DataFrame(data)
        stats = {}

        for col in df.columns:
            col_stats = {}

            if pd.api.types.is_numeric_dtype(df[col]):
                col_stats["mean"] = df[col].mean()
                col_stats["median"] = df[col].median()
                col_stats["std"] = df[col].std()
                col_stats["min"] = df[col].min()
                col_stats["max"] = df[col].max()
                col_stats["missing"] = df[col].isna().sum()
            elif pd.api.types.is_string_dtype(df[col]) or pd.api.types.is_object_dtype(df[col]):
                col_stats["unique_values"] = df[col].nunique()
                # Cap the value_counts summary: with ~100+ distinct values a
                # top-5 listing is not meaningful.
                col_stats["most_common"] = df[col].value_counts().head(5).to_dict() if df[col].nunique() < 100 else "Too many unique values"
                col_stats["missing"] = df[col].isna().sum()

            stats[col] = col_stats

        return stats
    except Exception as e:
        st.error(f"Error generating statistics: {str(e)}")
        return None

def create_visualization(data, viz_type, x_col=None, y_col=None, hue_col=None):
    """Build a matplotlib figure of *viz_type* from *data*.

    Plots requiring missing columns trigger an ``st.warning`` and return
    ``None``; unknown *viz_type* values yield an empty titled figure.
    Any exception is reported via ``st.error`` and returns ``None``.
    """
    try:
        df = pd.DataFrame(data)
        fig, ax = plt.subplots(figsize=(10, 6))

        # The four x/y plots share one call signature — dispatch via table.
        xy_plots = {
            "Bar Chart": (sns.barplot, "Bar charts require both X and Y columns."),
            "Line Chart": (sns.lineplot, "Line charts require both X and Y columns."),
            "Scatter Plot": (sns.scatterplot, "Scatter plots require both X and Y columns."),
            "Box Plot": (sns.boxplot, "Box plots require both X and Y columns."),
        }

        if viz_type in xy_plots:
            plot_fn, warning = xy_plots[viz_type]
            if not (x_col and y_col):
                st.warning(warning)
                return None
            plot_fn(x=x_col, y=y_col, hue=hue_col, data=df, ax=ax)
        elif viz_type == "Histogram":
            if not x_col:
                st.warning("Histograms require an X column.")
                return None
            sns.histplot(df[x_col], ax=ax)
        elif viz_type == "Count Plot":
            if not x_col:
                st.warning("Count plots require an X column.")
                return None
            sns.countplot(x=x_col, hue=hue_col, data=df, ax=ax)

        # Title/labels fall back to empty strings for missing columns.
        plt.title(f"{viz_type} of {y_col if y_col else ''} vs {x_col if x_col else ''}")
        plt.xlabel(x_col if x_col else "")
        plt.ylabel(y_col if y_col else "")
        plt.tight_layout()

        return fig

    except Exception as e:
        st.error(f"Error creating visualization: {str(e)}")
        return None

def get_popular_datasets(category=None, limit=10):
    """Return up to *limit* well-known HuggingFace dataset names.

    With a recognized *category* ("Text", "Image", "Audio", "Multimodal")
    only that category's datasets are returned; otherwise the categories
    are flattened in order and truncated to *limit*.
    """
    popular_datasets = {
        "Text": ["glue", "imdb", "squad", "wikitext", "ag_news"],
        "Image": ["cifar10", "cifar100", "mnist", "fashion_mnist", "coco"],
        "Audio": ["common_voice", "librispeech_asr", "voxpopuli", "voxceleb", "audiofolder"],
        "Multimodal": ["conceptual_captions", "flickr8k", "hateful_memes", "nlvr", "vqa"]
    }

    names = popular_datasets.get(category) if category else None
    if names is not None:
        return names[:limit]

    # Unknown or absent category: flatten every category in order.
    flattened = [name for cat_list in popular_datasets.values() for name in cat_list]
    return flattened[:limit]