import streamlit as st
import pandas as pd
from dotenv import load_dotenv
from datasets import load_dataset
import json
import re
from openai import OpenAI
import os
from config import DATASETS, MODELS
import altair as alt
from concurrent.futures import ThreadPoolExecutor, as_completed
from tenacity import retry, wait_exponential, stop_after_attempt, retry_if_exception_type
import threading
from anthropic import Anthropic
import google.generativeai as genai
import hmac
from datetime import datetime
from huggingface_hub import CommitScheduler
from pathlib import Path
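# config.py is expected to expose two dicts whose shapes are inferred from how they
# are used below (only "loader", "provider", and "model_id" are actually read here).
# A minimal sketch, with illustrative values, might look like:
#
#   DATASETS = {"MedMCQA": {"loader": "openlifescienceai/medmcqa"}}
#   MODELS   = {"GPT-4o": {"provider": "openai", "model_id": "gpt-4o"}}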

load_dotenv()

st.set_page_config(page_title="LLM Healthcare Benchmarking", layout="wide")

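# Results are appended to a local CSV by multiple worker threads, so writes are
# serialized with a lock; the CommitScheduler below mirrors this folder to the Hub.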
WRITE_LOCK = threading.Lock()
DATA_DIR = Path("data")
DATA_DIR.mkdir(exist_ok=True)
RESULTS_FILE = DATA_DIR / "results.csv"


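# Background persistence: periodically commits the contents of DATA_DIR to a
# Hugging Face dataset repo. Requires HF_REPO_ID and HF_TOKEN in the environment.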
scheduler = CommitScheduler(
    repo_id=os.getenv("HF_REPO_ID"),        # e.g. "username/dataset-repo"
    repo_type="dataset",
    folder_path=DATA_DIR,
    path_in_repo="data",
    every=10,                               # commit interval, in minutes
    token=os.getenv("HF_TOKEN")
)

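# Create the session-state keys used across reruns so later code can read them
# without existence checks.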
def initialize_session_state():
    if 'api_configured' not in st.session_state:
        st.session_state.api_configured = False
    if 'togetherai_client' not in st.session_state:
        st.session_state.togetherai_client = None
    if 'openai_client' not in st.session_state:
        st.session_state.openai_client = None
    if 'anthropic_client' not in st.session_state:
        st.session_state.anthropic_client = None
    if 'all_results' not in st.session_state:
        st.session_state.all_results = {}
    if 'detailed_model' not in st.session_state:
        st.session_state.detailed_model = None
    if 'detailed_dataset' not in st.session_state:
        st.session_state.detailed_dataset = None
    if 'last_evaluated_dataset' not in st.session_state:
        st.session_state.last_evaluated_dataset = None

initialize_session_state()

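# Sidebar API setup. "Stored keys" mode checks the entered credentials against
# STREAMLIT_USERNAME / STREAMLIT_PASSWORD and then reads TOGETHERAI_API_KEY,
# OPENAI_API_KEY, ANTHROPIC_API_KEY and GEMINI_API_KEY from the environment;
# otherwise the user supplies keys directly. Together AI is reached through the
# OpenAI client pointed at its OpenAI-compatible endpoint.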
def setup_api_clients():
    with st.sidebar:
        st.title("API Configuration")
        
        use_stored = st.checkbox("Use the stored API keys")
        
        if use_stored:
            username = st.text_input("Username")
            password = st.text_input("Password", type="password")
            
            if st.button("Verify Credentials"):
                stored_username = os.getenv("STREAMLIT_USERNAME", "")
                stored_password = os.getenv("STREAMLIT_PASSWORD", "")
                
                if (hmac.compare_digest(username, stored_username) and 
                    hmac.compare_digest(password, stored_password)):
                    try:
                        st.session_state.togetherai_client = OpenAI(
                            api_key=os.getenv('TOGETHERAI_API_KEY'),
                            base_url="https://api.together.xyz/v1"
                        )
                        st.session_state.openai_client = OpenAI(
                            api_key=os.getenv('OPENAI_API_KEY')
                        )
                        st.session_state.anthropic_client = Anthropic(
                            api_key=os.getenv('ANTHROPIC_API_KEY')
                        )
                        genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
                        
                        st.session_state.api_configured = True
                        st.success("Successfully configured the API clients with stored keys!")
                    except Exception as e:
                        st.error(f"Error initializing API clients: {str(e)}")
                        st.session_state.api_configured = False
                else:
                    st.error("Invalid credentials. Please try again or use your own API keys.")
                    st.session_state.api_configured = False
        else:
            st.subheader("Enter Your API Keys")
            togetherai_key = st.text_input("Together AI API Key", type="password", key="togetherai_key")
            openai_key = st.text_input("OpenAI API Key", type="password", key="openai_key")
            anthropic_key = st.text_input("Anthropic API Key", type="password", key="anthropic_key")
            gemini_key = st.text_input("Gemini API Key", type="password", key="gemini_key")
            
            if st.button("Initialize with the provided keys"):
                try:
                    st.session_state.togetherai_client = OpenAI(
                        api_key=togetherai_key,
                        base_url="https://api.together.xyz/v1"
                    )
                    st.session_state.openai_client = OpenAI(
                        api_key=openai_key
                    )
                    st.session_state.anthropic_client = Anthropic(
                        api_key=anthropic_key
                    )
                    genai.configure(api_key=gemini_key)
                    
                    st.session_state.api_configured = True
                    st.success("Successfully configured the API clients with provided keys!")
                except Exception as e:
                    st.error(f"Error initializing API clients: {str(e)}")
                    st.session_state.api_configured = False

setup_api_clients()


# Note: CommitScheduler begins its background commit loop as soon as it is
# instantiated above, so no explicit start call is needed here.

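# Throttle outbound provider calls: each worker thread must acquire the semaphore
# before issuing a request.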
MAX_CONCURRENT_CALLS = 5
semaphore = threading.Semaphore(MAX_CONCURRENT_CALLS)

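# Load and cache one of the datasets declared in DATASETS. The loader is expected
# to yield MedMCQA-style rows (question, opa..opd, cop, choice_type, subject_name,
# topic_name, exp); only single-choice questions are kept.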
@st.cache_data
def load_dataset_by_name(dataset_name, split="train"):
    dataset_config = DATASETS[dataset_name]
    dataset = load_dataset(dataset_config["loader"])
    df = pd.DataFrame(dataset[split])
    df = df[df['choice_type'] == 'single']
    
    questions = []
    for _, row in df.iterrows():
        options = [row['opa'], row['opb'], row['opc'], row['opd']]
        correct_answer = options[row['cop']]
        
        question_dict = {
            'question': row['question'],
            'options': options,
            'correct_answer': correct_answer,
            'subject_name': row['subject_name'],
            'topic_name': row['topic_name'],
            'explanation': row['exp']
        }
        questions.append(question_dict)
    
    st.write(f"Loaded {len(questions)} single-select questions from `{dataset_name}`")
    return questions

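# Query a single model for a single question. tenacity retries with exponential
# backoff, although the broad try/except below converts most failures into
# "Error: ..." return values, so retries only fire on exceptions that escape it.
# Prompt placeholders {question} and {options} are filled by plain string
# replacement, and the answer is taken from the first JSON object in the response.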
@retry(
    wait=wait_exponential(multiplier=1, min=4, max=10),
    stop=stop_after_attempt(5),
    retry=retry_if_exception_type(Exception)
)
def get_model_response(question, options, prompt_template, model_name, clients):
    with semaphore:
        try:
            model_config = MODELS[model_name]
            options_text = "\n".join([f"{chr(65+i)}. {opt}" for i, opt in enumerate(options)])
            prompt = prompt_template.replace("{question}", question).replace("{options}", options_text)

            provider = model_config["provider"]

            if provider == "togetherai":
                response = clients["togetherai"].chat.completions.create(
                    model=model_config["model_id"],
                    messages=[{"role": "user", "content": prompt}]
                )
                response_text = response.choices[0].message.content.strip()

            elif provider == "openai":
                response = clients["openai"].chat.completions.create(
                    model=model_config["model_id"],
                    messages=[{"role": "user", "content": prompt}]
                )
                response_text = response.choices[0].message.content.strip()

            elif provider == "anthropic":
                response = clients["anthropic"].messages.create(
                    model=model_config["model_id"],
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=4096
                )
                response_text = response.content[0].text

            elif provider == "google":
                model = genai.GenerativeModel(model_name=model_config["model_id"])
                chat_session = model.start_chat(history=[])
                response_text = chat_session.send_message(prompt).text

            else:
                # Guard against providers missing from the MODELS config; without
                # this, response_text would be unbound below.
                return f"Error: Unknown provider '{provider}'", ""

            # Extract the first JSON object from the response
            json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
            if not json_match:
                return "Error: Invalid response format", response_text
            
            json_response = json.loads(json_match.group(0))
            answer = json_response.get('answer', '').strip()
            answer = re.sub(r'^[A-D]\.\s*', '', answer)
            
            if not any(answer.lower() == opt.lower() for opt in options):
                return f"Error: Answer '{answer}' does not match any options", response_text
            
            return answer, response_text
        except Exception as e:
            return f"Error: {str(e)}", str(e)


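# A response counts as correct only if it equals the correct option text
# (case-insensitive); any "Error: ..." string is scored as incorrect.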
def evaluate_response(model_response, correct_answer):
    if model_response.startswith("Error:"):
        return False
    is_correct = model_response.lower().strip() == correct_answer.lower().strip()
    return is_correct

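# Evaluate one (question, model) pair and append the result row to RESULTS_FILE
# under WRITE_LOCK so the CommitScheduler can pick it up on its next commit.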
def process_single_evaluation(question, prompt_template, model_name, clients, last_evaluated_dataset):
    answer, response_text = get_model_response(
        question['question'],
        question['options'],
        prompt_template,
        model_name,
        clients 
    )
    is_correct = evaluate_response(answer, question['correct_answer'])
    result = {
        'dataset': last_evaluated_dataset, 
        'model': model_name,
        'question': question['question'],
        'correct_answer': question['correct_answer'],
        'subject': question['subject_name'],
        'options': ' | '.join(question['options']),
        'model_response': answer,
        'is_correct': is_correct,
        'explanation': question['explanation'],
        'timestamp': datetime.utcnow().isoformat()
    }
    
    with WRITE_LOCK:
        if RESULTS_FILE.exists():
            existing_df = pd.read_csv(RESULTS_FILE)
            # DataFrame.append() was removed in pandas 2.0; use pd.concat instead.
            updated_df = pd.concat([existing_df, pd.DataFrame([result])], ignore_index=True)
        else:
            updated_df = pd.DataFrame([result])

        updated_df.to_csv(RESULTS_FILE, index=False)
    
    return result

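# Fan evaluations out over a thread pool, skipping (model, question) pairs that
# are already in RESULTS_FILE so an interrupted run can resume, and report
# progress through the supplied callback.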
def process_evaluations_concurrently(questions, prompt_template, models_to_evaluate, progress_callback, clients, last_evaluated_dataset):
    results = []
    total_iterations = len(models_to_evaluate) * len(questions)
    current_iteration = 0

    if RESULTS_FILE.exists():
        existing_df = pd.read_csv(RESULTS_FILE)
        completed = set(zip(existing_df['model'], existing_df['question']))
    else:
        completed = set()

    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CALLS) as executor:
        future_to_params = {}
        for model_name in models_to_evaluate:
            for question in questions:
                if (model_name, question['question']) in completed:
                    current_iteration += 1
                    progress_callback(current_iteration, total_iterations)
                    continue  # Skip already completed evaluations
                future = executor.submit(
                    process_single_evaluation, 
                    question, 
                    prompt_template, 
                    model_name, 
                    clients, 
                    last_evaluated_dataset
                )
                future_to_params[future] = (model_name, question)
        
        for future in as_completed(future_to_params):
            result = future.result()
            results.append(result)
            current_iteration += 1
            progress_callback(current_iteration, total_iterations)
    
    return results


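# Page flow: reload persisted results, offer a reset button, let the user pick a
# dataset, models and prompt template, run the evaluation, then render the
# comparison chart, per-question details and a CSV download.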
def main():
    # Reload any persisted results so the UI reflects prior runs after a rerun.
    if RESULTS_FILE.exists():
        existing_df = pd.read_csv(RESULTS_FILE)
        all_results = {}
        for _, row in existing_df.iterrows():
            all_results.setdefault(row['model'], []).append(row.to_dict())
        st.session_state.all_results = all_results
        st.session_state.last_evaluated_dataset = existing_df['dataset'].iloc[-1]
        st.info(f"Loaded existing results from `{RESULTS_FILE}`.")
    else:
        st.session_state.all_results = {}
        st.session_state.last_evaluated_dataset = None
        st.info("No existing results found. Ready to start fresh.")

    with st.sidebar:
        if st.button("Reset Results"):
            if RESULTS_FILE.exists():
                try:
                    RESULTS_FILE.unlink()
                    st.session_state.all_results = {}
                    st.session_state.last_evaluated_dataset = None
                    st.success("Results have been reset.")
                except Exception as e:
                    st.error(f"Error deleting file: {str(e)}")
            else:
                st.info("No results to reset.")

    col1, col2 = st.columns(2)
    with col1:
        selected_dataset = st.selectbox(
            "Select Dataset",
            options=list(DATASETS.keys()),
            help="Choose the dataset to evaluate on"
        )
    with col2:
        selected_models = st.multiselect(
            "Select Model(s)",
            options=list(MODELS.keys()),
            default=[list(MODELS.keys())[0]],
            help="Choose one or more models to evaluate."
        )

    models_to_evaluate = selected_models

    default_prompt = '''You are a medical AI assistant. Please answer the following multiple choice question.
Question: {question}

Options:
{options}

## Output Format:
Please provide your answer in JSON format that contains an "answer" field.
You may include any additional fields in your JSON response that you find relevant, such as:
- "choice reasoning": your detailed reasoning
- "elimination reasoning": why you ruled out other options

Example response format:
{
    "answer": "exact option text here(e.g., A. xxx, B. xxx, C. xxx)",
    "choice reasoning": "your detailed reasoning here",
    "elimination reasoning": "why you ruled out other options"
}

Important:
- Only the "answer" field will be used for evaluation
- Ensure your response is in valid JSON format'''

    # Customize Prompt Template
    col1, col2 = st.columns([2, 1])
    with col1:
        prompt_template = st.text_area(
            "Customize Prompt Template", 
            default_prompt, 
            height=400,
            help="Edit the prompt template before starting the evaluation."
        )
    
    with col2:
        st.markdown("""
        ### Prompt Variables
        - `{question}`: The medical question
        - `{options}`: The multiple choice options
        """)

    # Load Dataset
    if st.session_state.api_configured:
        with st.spinner("Loading dataset..."):
            questions = load_dataset_by_name(selected_dataset)
    else:
        st.warning("Please configure the API keys in the sidebar to load datasets and proceed.")
        questions = []

    # Filter by Subject
    if questions:
        subjects = sorted(list(set(q['subject_name'] for q in questions)))
        selected_subject = st.selectbox("Filter by subject", ["All"] + subjects)
        
        if selected_subject != "All":
            questions = [q for q in questions if q['subject_name'] == selected_subject]

        # Number of Questions to Evaluate
        num_questions = st.number_input(
            "Number of questions to evaluate", 
            min_value=1, 
            max_value=len(questions), 
            value=min(10, len(questions)), 
            step=1
        )

        # Start Evaluation Button
        if st.button("Start Evaluation"):
            if not models_to_evaluate:
                st.error("Please select at least one model to evaluate.")
            else:
                with st.spinner("Starting evaluation..."):
                    selected_questions = questions[:num_questions]
                    
                    clients = {
                        "togetherai": st.session_state["togetherai_client"],
                        "openai": st.session_state["openai_client"],
                        "anthropic": st.session_state["anthropic_client"]
                    }
                    
                    # Tag new results with the dataset currently being evaluated
                    last_evaluated_dataset = selected_dataset

                    progress_container = st.container()
                    progress_bar = progress_container.progress(0)
                    status_text = progress_container.empty()
                    
                    def update_progress(current, total):
                        progress = current / total
                        progress_bar.progress(progress)
                        status_text.text(f"Progress: {current}/{total} evaluations completed")

                    results = process_evaluations_concurrently(
                        selected_questions,
                        prompt_template,
                        models_to_evaluate,
                        update_progress,
                        clients,
                        last_evaluated_dataset
                    )
                
                # Update Session State with New Results
                all_results = st.session_state.all_results.copy()
                for result in results:
                    model = result.pop('model')
                    if model not in all_results:
                        all_results[model] = []
                    all_results[model].append(result)
                
                st.session_state.all_results = all_results
                st.session_state.last_evaluated_dataset = selected_dataset

                # Set Default Detailed Model and Dataset if Not Set
                if st.session_state.detailed_model is None and all_results:
                    st.session_state.detailed_model = list(all_results.keys())[0]
                if st.session_state.detailed_dataset is None:
                    st.session_state.detailed_dataset = selected_dataset

                st.success("Evaluation completed!")
                st.rerun()

    # Display Evaluation Results
    if st.session_state.all_results:
        st.subheader("Evaluation Results")
        model_metrics = {}

        for model_name, results in st.session_state.all_results.items():
            df = pd.DataFrame(results)
            metrics = {
                'Accuracy': df['is_correct'].mean(),
            }
            model_metrics[model_name] = metrics

        metrics_df = pd.DataFrame(model_metrics).T.reset_index().rename(columns={'index': 'Model'})

        st.subheader("Model Performance Comparison")
        accuracy_chart = alt.Chart(
            metrics_df
        ).mark_bar().encode(
            x=alt.X('Model:N', title=None),
            y=alt.Y('Accuracy:Q', title='Accuracy', scale=alt.Scale(domain=[0, 1])),
            color=alt.Color('Model:N', scale=alt.Scale(scheme='blues')),
            tooltip=['Model:N', 'Accuracy:Q']
        ).properties(
            height=300,
            title={
                "text": "Model Accuracy",
                "anchor": "middle",
                "fontSize": 20
            }
        ).interactive()

        st.altair_chart(accuracy_chart, use_container_width=True)

    # Display Detailed Results
    if st.session_state.all_results:
        st.subheader("Detailed Results")
        
        def update_model():
            st.session_state.detailed_model = st.session_state.model_select
            
        def update_dataset():
            st.session_state.detailed_dataset = st.session_state.dataset_select

        col1, col2 = st.columns(2)
        with col1:
            selected_model_details = st.selectbox(
                "Select model",
                options=list(st.session_state.all_results.keys()),
                key="model_select",
                on_change=update_model,
                index=list(st.session_state.all_results.keys()).index(st.session_state.detailed_model) 
                    if st.session_state.detailed_model in st.session_state.all_results else 0
            )
        
        with col2:
            selected_dataset_details = st.selectbox(
                "Select dataset",
                options=[st.session_state.last_evaluated_dataset] if st.session_state.last_evaluated_dataset else [],
                key="dataset_select",
                on_change=update_dataset
            )

        if selected_model_details and selected_model_details in st.session_state.all_results:
            results = st.session_state.all_results[selected_model_details]
            df = pd.DataFrame(results)
            accuracy = df['is_correct'].mean()
            
            st.metric("Accuracy", f"{accuracy:.2%}")
            
            for idx, result in enumerate(results):
                with st.expander(f"Question {idx + 1} - {result['subject']}"):
                    st.write("**Question:**", result['question'])
                    st.write("**Options:**")
                    for i, opt in enumerate(result['options'].split(' | ')):
                        st.write(f"{chr(65+i)}. {opt}")
                    
                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("**Model Response:**")
                        st.code(result.get('model_response', "N/A"))
                    with col2:
                        st.write("**Explanation:**")
                        st.code(result.get('explanation', "N/A"))
                    
                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("**Correct Answer:**", result['correct_answer'])
                        st.write("**Model Answer:**", result['model_response'])
                    with col2:
                        if result['is_correct']:
                            st.success("Correct!")
                        else:
                            st.error("Incorrect")
                    
                    st.write("**Timestamp:**", result['timestamp'])
        else:
            st.info(f"No results available for {selected_model_details} on {selected_dataset_details}. Please run the evaluation first.")

        st.markdown("---")
        st.subheader("Download Results")
        if RESULTS_FILE.exists():
            csv_data = RESULTS_FILE.read_text(encoding='utf-8')
            st.download_button(
                label="Download All Results as CSV",
                data=csv_data,
                file_name=f"all_models_{st.session_state.last_evaluated_dataset}_results.csv",
                mime="text/csv", 
                key="download_all_results"
            )
        else:
            st.info("No data available to download.")

if __name__ == "__main__":
    main()