import json
import logging
import os
import shutil
import sys
import tempfile
from pathlib import Path

import streamlit as st
import streamlit.components.v1 as components
from dotenv import load_dotenv
from omegaconf import OmegaConf
from rich.console import Console

from aide import Experiment

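# Both logging and rich console output are sent to stderr, keeping stdout
# clean for Streamlit itself.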
|
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stderr)],
)

logger = logging.getLogger("aide")
logger.setLevel(logging.INFO)

console = Console(file=sys.stderr)

|
class WebUI:
    """
    WebUI encapsulates the Streamlit application logic for the AIDE Machine
    Learning Engineer Agent.
    """

    def __init__(self):
        """
        Initialize the WebUI with environment variables and session state.
        """
        self.env_vars = self.load_env_variables()
        self.project_root = Path(__file__).parent
        self.config_session_state()
        self.setup_page()
|
    @staticmethod
    def load_env_variables():
        """
        Load API keys and environment variables from a .env file.

        Returns:
            dict: Dictionary containing API keys.
        """
|
        load_dotenv()
        return {
            "openai_key": os.getenv("OPENAI_API_KEY", ""),
            "anthropic_key": os.getenv("ANTHROPIC_API_KEY", ""),
        }
|
    @staticmethod
    def config_session_state():
        """
        Configure default values for Streamlit session state.
        """
|
if "is_running" not in st.session_state: |
|
st.session_state.is_running = False |
|
if "current_step" not in st.session_state: |
|
st.session_state.current_step = 0 |
|
if "total_steps" not in st.session_state: |
|
st.session_state.total_steps = 0 |
|
if "progress" not in st.session_state: |
|
st.session_state.progress = 0 |
|
if "results" not in st.session_state: |
|
st.session_state.results = None |
|
|
|
    @staticmethod
    def setup_page():
        """
        Set up the Streamlit page configuration and load custom CSS.
        """
        st.set_page_config(
            page_title="AIDE: Machine Learning Engineer Agent",
            layout="wide",
        )
        WebUI.load_css()
|
    @staticmethod
    def load_css():
        """
        Load custom CSS styles from the 'style.css' file next to this script.
        """
        css_file = Path(__file__).parent / "style.css"
        if css_file.exists():
            with open(css_file) as f:
                st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
        else:
            st.warning(f"CSS file not found at: {css_file}")
|
    def run(self):
        """
        Run the main logic of the Streamlit application.
        """
        self.render_sidebar()
        input_col, results_col = st.columns([1, 3])
        with input_col:
            self.render_input_section(results_col)
|
    def render_sidebar(self):
        """
        Render the sidebar with API key settings.
        """
        with st.sidebar:
            st.header("⚙️ Settings")
            st.markdown(
                "<p style='text-align: center;'>OpenAI API Key</p>",
                unsafe_allow_html=True,
            )
            openai_key = st.text_input(
                "OpenAI API Key",
                value=self.env_vars["openai_key"],
                type="password",
                label_visibility="collapsed",
            )
            st.markdown(
                "<p style='text-align: center;'>Anthropic API Key</p>",
                unsafe_allow_html=True,
            )
            anthropic_key = st.text_input(
                "Anthropic API Key",
                value=self.env_vars["anthropic_key"],
                type="password",
                label_visibility="collapsed",
            )
            if st.button("Save API Keys", use_container_width=True):
                st.session_state.openai_key = openai_key
                st.session_state.anthropic_key = anthropic_key
                st.success("API keys saved!")
|
    def render_input_section(self, results_col):
        """
        Render the input section of the application.

        Args:
            results_col (st.delta_generator.DeltaGenerator): The results column
                to pass to downstream methods.
        """
        st.header("Input")
        uploaded_files = self.handle_file_upload()
        goal_text, eval_text, num_steps = self.handle_user_inputs()
        if st.button("Run AIDE", type="primary", use_container_width=True):
            with st.spinner("AIDE is running..."):
                results = self.run_aide(
                    uploaded_files, goal_text, eval_text, num_steps, results_col
                )
                st.session_state.results = results
|
    def handle_file_upload(self):
        """
        Handle file uploads and example file loading.

        Returns:
            list: List of uploaded or example files.
        """
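        # Two input paths are supported: files uploaded through the widget, or
        # the bundled example task loaded from disk.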
|
        if not st.session_state.get("example_files"):
            uploaded_files = st.file_uploader(
                "Upload Data Files",
                accept_multiple_files=True,
                type=["csv", "txt", "json", "md"],
                label_visibility="collapsed",
            )
            if uploaded_files:
                # A fresh upload supersedes any previously loaded example files.
                st.session_state.pop("example_files", None)
                return uploaded_files

        if st.button(
            "Load Example Experiment", type="primary", use_container_width=True
        ):
            st.session_state.example_files = self.load_example_files()

        if st.session_state.get("example_files"):
            st.info("Example files loaded! Click 'Run AIDE' to proceed.")
            with st.expander("View Loaded Files", expanded=False):
                for file in st.session_state.example_files:
                    st.text(f"📄 {file['name']}")
            return st.session_state.example_files

        return []
|
    def handle_user_inputs(self):
        """
        Handle the goal, evaluation criteria, and number-of-steps inputs.

        Returns:
            tuple: Goal text, evaluation criteria text, and number of steps.
        """
        goal_text = st.text_area(
            "Goal",
            value=st.session_state.get("goal", ""),
            placeholder="Example: Predict the sales price for each house",
        )
        eval_text = st.text_area(
            "Evaluation Criteria",
            value=st.session_state.get("eval", ""),
            placeholder="Example: Use the RMSE metric between the logarithm of the predicted and observed values.",
        )
        num_steps = st.slider(
            "Number of Steps",
            min_value=1,
            max_value=20,
            value=st.session_state.get("steps", 10),
        )
        return goal_text, eval_text, num_steps
|
    @staticmethod
    def load_example_files():
        """
        Load example files from the 'example_tasks/house_prices' directory.

        Returns:
            list: List of example files with their paths.
        """
|
        package_root = Path(__file__).parent / "aide"
        example_dir = package_root / "example_tasks" / "house_prices"

        if not example_dir.exists():
            st.error(f"Example directory not found at: {example_dir}")
            return []

        example_files = []
        for file_path in example_dir.glob("*"):
            if file_path.suffix.lower() in [".csv", ".txt", ".json", ".md"]:
                with tempfile.NamedTemporaryFile(
                    delete=False, suffix=file_path.suffix
                ) as tmp_file:
                    tmp_file.write(file_path.read_bytes())
                    example_files.append({"name": file_path.name, "path": tmp_file.name})

        if not example_files:
            st.warning("No example files found in the example directory")

        # Pre-fill the goal and evaluation criteria for the example task.
        st.session_state["goal"] = "Predict the sales price for each house"
        st.session_state["eval"] = (
            "Use the RMSE metric between the logarithm of the predicted and observed values."
        )

        return example_files
|
    def run_aide(self, files, goal_text, eval_text, num_steps, results_col):
        """
        Run the AIDE experiment with the provided inputs.

        Args:
            files (list): List of uploaded or example files.
            goal_text (str): The goal of the experiment.
            eval_text (str): The evaluation criteria.
            num_steps (int): Number of steps to run.
            results_col (st.delta_generator.DeltaGenerator): Results column for
                displaying progress.

        Returns:
            dict: Dictionary containing the results of the experiment, or None
                if the run failed.
        """
        try:
            self.initialize_run_state(num_steps)
            self.set_api_keys()

            input_dir = self.prepare_input_directory(files)
            if not input_dir:
                return None

            experiment = self.initialize_experiment(input_dir, goal_text, eval_text)

            # Placeholders allow each UI section to be updated in place.
            progress_placeholder = results_col.empty()
            config_placeholder = results_col.empty()
            results_placeholder = results_col.empty()
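            # Run the experiment one step at a time so the UI can refresh the
            # progress bar, configuration, and live results between steps.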
|
            for step in range(num_steps):
                st.session_state.current_step = step + 1
                progress = (step + 1) / num_steps
                st.session_state.progress = progress

                with progress_placeholder.container():
                    st.markdown(
                        f"### 🔥 Running Step {st.session_state.current_step}/{st.session_state.total_steps}"
                    )
                    st.progress(progress)

                # Show the configuration once, before the first step runs.
                if step == 0:
                    with config_placeholder.container():
                        st.markdown("### 📋 Configuration")
                        st.code(OmegaConf.to_yaml(experiment.cfg), language="yaml")

                experiment.run(steps=1)

                with results_placeholder.container():
                    self.render_live_results(experiment)

                # Clear the config panel once the first live results exist.
                if step == 0:
                    config_placeholder.empty()

            progress_placeholder.empty()

            st.session_state.is_running = False
            st.session_state.results = self.collect_results(experiment)
            return st.session_state.results

        except Exception as e:
            st.session_state.is_running = False
            console.print_exception()
            st.error(f"Error occurred: {str(e)}")
            return None
|
    @staticmethod
    def initialize_run_state(num_steps):
        """
        Initialize the running state for the experiment.

        Args:
            num_steps (int): Total number of steps in the experiment.
        """
        st.session_state.is_running = True
        st.session_state.current_step = 0
        st.session_state.total_steps = num_steps
        st.session_state.progress = 0
|
    @staticmethod
    def set_api_keys():
        """
        Set the API keys in the environment variables from the session state.
        """
|
        if st.session_state.get("openai_key"):
            os.environ["OPENAI_API_KEY"] = st.session_state.openai_key
        if st.session_state.get("anthropic_key"):
            os.environ["ANTHROPIC_API_KEY"] = st.session_state.anthropic_key
|
    def prepare_input_directory(self, files):
        """
        Prepare the input directory and handle uploaded files.

        Args:
            files (list): List of uploaded or example files.

        Returns:
            Path: The input directory path, or None if files are missing.
        """
|
        input_dir = self.project_root / "input"
        input_dir.mkdir(parents=True, exist_ok=True)

        if files:
            for file in files:
                if isinstance(file, dict):
                    shutil.copy2(file["path"], input_dir / file["name"])
                else:
                    with open(input_dir / file.name, "wb") as f:
                        f.write(file.getbuffer())
        else:
            st.error("Please upload data files")
            return None
        return input_dir
|
    @staticmethod
    def initialize_experiment(input_dir, goal_text, eval_text):
        """
        Initialize the AIDE Experiment.

        Args:
            input_dir (Path): Path to the input directory.
            goal_text (str): The goal of the experiment.
            eval_text (str): The evaluation criteria.

        Returns:
            Experiment: The initialized Experiment object.
        """
        return Experiment(data_dir=str(input_dir), goal=goal_text, eval=eval_text)
|
    @staticmethod
    def collect_results(experiment):
        """
        Collect the results from the experiment.

        Args:
            experiment (Experiment): The Experiment object.

        Returns:
            dict: Dictionary containing the collected results.
        """
        solution_path = experiment.cfg.log_dir / "best_solution.py"
        if solution_path.exists():
            solution = solution_path.read_text()
        else:
            solution = "No solution found"
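        # Flatten the journal nodes for display; json.dumps(default=str) below
        # handles any values that are not natively JSON-serializable.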
|
        journal_data = [
            {
                "step": node.step,
                "code": str(node.code),
                "metric": str(node.metric.value) if node.metric else None,
                "is_buggy": node.is_buggy,
            }
            for node in experiment.journal.nodes
        ]

        results = {
            "solution": solution,
            "config": OmegaConf.to_yaml(experiment.cfg),
            "journal": json.dumps(journal_data, indent=2, default=str),
            "tree_path": str(experiment.cfg.log_dir / "tree_plot.html"),
        }
        return results
|
    @staticmethod
    def render_tree_visualization(results):
        """
        Render the tree visualization from the experiment results.

        Args:
            results (dict): The results dictionary containing paths and data.
        """
        if "tree_path" in results:
            tree_path = Path(results["tree_path"])
            logger.info(f"Loading tree visualization from: {tree_path}")
            if tree_path.exists():
                with open(tree_path, "r", encoding="utf-8") as f:
                    html_content = f.read()
                components.html(html_content, height=600, scrolling=True)
            else:
                st.error(f"Tree visualization file not found at: {tree_path}")
                logger.error(f"Tree file not found at: {tree_path}")
        else:
            st.info("No tree visualization available for this run.")
|
    @staticmethod
    def render_best_solution(results):
        """
        Display the best solution code.

        Args:
            results (dict): The results dictionary containing the solution.
        """
        if "solution" in results:
            st.code(results["solution"], language="python")
        else:
            st.info("No solution available.")
|
    @staticmethod
    def render_config(results):
        """
        Display the configuration used in the experiment.

        Args:
            results (dict): The results dictionary containing the config.
        """
        if "config" in results:
            st.code(results["config"], language="yaml")
        else:
            st.info("No configuration available.")
|
    @staticmethod
    def render_journal(results):
        """
        Display the experiment journal as JSON.

        Args:
            results (dict): The results dictionary containing the journal.
        """
        if "journal" in results:
            try:
                journal_data = json.loads(results["journal"])
                formatted_journal = json.dumps(journal_data, indent=2)
                st.code(formatted_journal, language="json")
            except json.JSONDecodeError:
                # Fall back to the raw string if the journal is not valid JSON.
                st.code(results["journal"], language="json")
        else:
            st.info("No journal available.")
|
    @staticmethod
    def get_best_metric(results):
        """
        Extract the best validation metric from the results.

        Args:
            results (dict): The results dictionary containing the journal.

        Returns:
            float: The best metric value, or None if no metric could be parsed.
        """
        try:
            journal_data = json.loads(results["journal"])
            metrics = []
            for node in journal_data:
                if node["metric"] is not None:
                    try:
                        metric_value = float(node["metric"])
                        metrics.append(metric_value)
                    except (ValueError, TypeError):
                        continue
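            # NOTE: max() assumes higher is better; for loss-style metrics such
            # as RMSE, this picks the worst score rather than the best.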
|
            return max(metrics) if metrics else None
        except (json.JSONDecodeError, KeyError):
            return None
|
    @staticmethod
    def render_validation_plot(results, step):
        """
        Render the validation score plot.

        Args:
            results (dict): The results dictionary.
            step (int): Current step number, used for unique chart key generation.
        """
        try:
            journal_data = json.loads(results["journal"])
            steps = []
            metrics = []

            for node in journal_data:
                if node["metric"] is not None and node["metric"].lower() != "none":
                    try:
                        metric_value = float(node["metric"])
                        steps.append(node["step"])
                        metrics.append(metric_value)
                    except (ValueError, TypeError):
                        continue

            if metrics:
                # plotly is imported lazily, only when there is data to plot.
                import plotly.graph_objects as go

                fig = go.Figure()
                fig.add_trace(
                    go.Scatter(
                        x=steps,
                        y=metrics,
                        mode="lines+markers",
                        name="Validation Score",
                        line=dict(color="#F04370"),
                        marker=dict(color="#F04370"),
                    )
                )

                fig.update_layout(
                    title="Validation Score Progress",
                    xaxis_title="Step",
                    yaxis_title="Validation Score",
                    template="plotly_white",
                    hovermode="x unified",
                    plot_bgcolor="rgba(0,0,0,0)",
                    paper_bgcolor="rgba(0,0,0,0)",
                )

                # A step-specific key keeps chart elements unique when the plot
                # is re-rendered on every step.
                st.plotly_chart(fig, use_container_width=True, key=f"plot_{step}")
            else:
                st.info("No validation metrics available to plot")

        except (json.JSONDecodeError, KeyError):
            st.error("Could not parse validation metrics data")
|
    def render_live_results(self, experiment):
        """
        Render live results.

        Args:
            experiment (Experiment): The Experiment object.
        """
        results = self.collect_results(experiment)

        tabs = st.tabs(
            [
                "Tree Visualization",
                "Best Solution",
                "Config",
                "Journal",
                "Validation Plot",
            ]
        )

        with tabs[0]:
            self.render_tree_visualization(results)
        with tabs[1]:
            self.render_best_solution(results)
        with tabs[2]:
            self.render_config(results)
        with tabs[3]:
            self.render_journal(results)
        with tabs[4]:
            best_metric = self.get_best_metric(results)
            if best_metric is not None:
                st.metric("Best Validation Score", f"{best_metric:.4f}")
            self.render_validation_plot(results, step=st.session_state.current_step)
|
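# A typical way to launch this app (the script name here is illustrative):
#   streamlit run app.py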
|
if __name__ == "__main__":
    app = WebUI()
    app.run()