import streamlit as st from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer from huggingface_hub import login from threading import Thread import PyPDF2 import pandas as pd import torch import time # Check if 'peft' is installed try: from peft import PeftModel, PeftConfig except ImportError: raise ImportError( "The 'peft' library is required but not installed. " "Please install it using: `pip install peft`" ) # Set page configuration st.set_page_config( page_title="WizNerd Insp", page_icon="🚀", layout="centered" ) # Model names BASE_MODEL_NAME = "HuggingFaceTB/SmolLM2-360M" MODEL_OPTIONS = { "Full Fine-Tuned": "amiguel/SmolLM2-360M-concise-reasoning", "LoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-lora", "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora" # Hypothetical, adjust if needed } # Title with rocket emojis st.title("🚀 WizNerd Insp 🚀") # Configure Avatars USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png" BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg" # Sidebar configuration with st.sidebar: st.header("Authentication 🔒") hf_token = st.text_input("Hugging Face Token", type="password", help="Get your token from https://huggingface.co/settings/tokens") st.header("Model Selection 🤖") model_type = st.selectbox("Choose Model Type", list(MODEL_OPTIONS.keys()), index=0) selected_model = MODEL_OPTIONS[model_type] st.header("Upload Documents 📂") uploaded_file = st.file_uploader( "Choose a PDF or XLSX file", type=["pdf", "xlsx"], label_visibility="collapsed" ) # Initialize chat history if "messages" not in st.session_state: st.session_state.messages = [] # File processing function @st.cache_data def process_file(uploaded_file): if uploaded_file is None: return "" try: if uploaded_file.type == "application/pdf": pdf_reader = PyPDF2.PdfReader(uploaded_file) return "\n".join([page.extract_text() for page in pdf_reader.pages]) elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": df = pd.read_excel(uploaded_file) return df.to_markdown() except Exception as e: st.error(f"📄 Error processing file: {str(e)}") return "" # Model loading function @st.cache_resource def load_model(hf_token, model_type, selected_model): try: if not hf_token: st.error("🔐 Authentication required! Please provide a Hugging Face token.") return None login(token=hf_token) # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=hf_token) # Load model based on type if model_type == "Full Fine-Tuned": # Load full fine-tuned model directly model = AutoModelForCausalLM.from_pretrained( selected_model, torch_dtype=torch.bfloat16, device_map="auto", token=hf_token ) else: # Load base model and apply PEFT adapter base_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto", token=hf_token ) model = PeftModel.from_pretrained( base_model, selected_model, torch_dtype=torch.bfloat16, is_trainable=False, # Inference mode token=hf_token ) return model, tokenizer except Exception as e: st.error(f"🤖 Model loading failed: {str(e)}") return None # Generation function with KV caching def generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True): full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:" streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True ) inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device) generation_kwargs = { "input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"], "max_new_tokens": 1024, "temperature": 0.7, "top_p": 0.9, "repetition_penalty": 1.1, "do_sample": True, "use_cache": use_cache, "streamer": streamer } Thread(target=model.generate, kwargs=generation_kwargs).start() return streamer # Display chat messages for message in st.session_state.messages: try: avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR with st.chat_message(message["role"], avatar=avatar): st.markdown(message["content"]) except: with st.chat_message(message["role"]): st.markdown(message["content"]) # Chat input handling if prompt := st.chat_input("Ask your inspection question..."): if not hf_token: st.error("🔑 Authentication required!") st.stop() # Load model if not already loaded or if model type changed if "model" not in st.session_state or st.session_state.get("model_type") != model_type: model_data = load_model(hf_token, model_type, selected_model) if model_data is None: st.error("Failed to load model. Please check your token and try again.") st.stop() st.session_state.model, st.session_state.tokenizer = model_data st.session_state.model_type = model_type model = st.session_state.model tokenizer = st.session_state.tokenizer # Add user message with st.chat_message("user", avatar=USER_AVATAR): st.markdown(prompt) st.session_state.messages.append({"role": "user", "content": prompt}) # Process file file_context = process_file(uploaded_file) # Generate response with KV caching if model and tokenizer: try: with st.chat_message("assistant", avatar=BOT_AVATAR): start_time = time.time() streamer = generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True) response_container = st.empty() full_response = "" for chunk in streamer: cleaned_chunk = chunk.replace("", "").replace("", "").strip() full_response += cleaned_chunk + " " response_container.markdown(full_response + "▌", unsafe_allow_html=True) # Calculate performance metrics end_time = time.time() input_tokens = len(tokenizer(prompt)["input_ids"]) output_tokens = len(tokenizer(full_response)["input_ids"]) speed = output_tokens / (end_time - start_time) # Calculate costs (hypothetical pricing model) input_cost = (input_tokens / 1000000) * 5 # $5 per million input tokens output_cost = (output_tokens / 1000000) * 15 # $15 per million output tokens total_cost_usd = input_cost + output_cost total_cost_aoa = total_cost_usd * 1160 # Convert to AOA (Angolan Kwanza) # Display metrics st.caption( f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | " f"🕒 Speed: {speed:.1f}t/s | 💰 Cost (USD): ${total_cost_usd:.4f} | " f"💵 Cost (AOA): {total_cost_aoa:.4f}" ) response_container.markdown(full_response) st.session_state.messages.append({"role": "assistant", "content": full_response}) except Exception as e: st.error(f"⚡ Generation error: {str(e)}") else: st.error("🤖 Model not loaded!")