import streamlit as st
import torch
import pandas as pd
from transformers import BertForSequenceClassification, BertTokenizer
from langchain_core.messages import HumanMessage, AIMessage

from src.RAG_pipeline import RAGConfig, RAGPipeline
from src.language_detector import LanguageDetector

# ========================
# Streamlit App Definition
# ========================
def initialize_session_state():
    """Initialize session state variables if they are not already set."""
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "pipeline" not in st.session_state:
        st.session_state.pipeline = None
    if "pipeline_language" not in st.session_state:
        st.session_state.pipeline_language = None
    if "language_detector" not in st.session_state:
        try:
            st.session_state.language_detector = LanguageDetector(config_path="./Configs/config.yaml")
        except Exception as e:
            # Keep the key defined so later lookups fail gracefully.
            st.session_state.language_detector = None
            st.error(f"Error initializing language detector: {e}")
def get_pipeline_for_language(language: str) -> RAGPipeline:
    """
    Load the RAG configuration from YAML and override the persist directory
    based on the detected language.
    """
    # Choose the Chroma persist directory based on the detected language.
    desired_persist_dir = "./data/chroma_db_ar" if language == "darija" else "./data/chroma_db_fr"
    config = RAGConfig.load_from_yaml("./Configs/config.yaml")
    config.persist_dir = desired_persist_dir  # override according to detected language
    return RAGPipeline(config)
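
# Illustration of the mapping above (not executed):
#   get_pipeline_for_language("darija")  ->  vector store at ./data/chroma_db_ar
#   get_pipeline_for_language("fr")      ->  vector store at ./data/chroma_db_fr
# Any label other than "darija" falls back to the French store.
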
# =============================
# BERT answer classifier setup
# =============================
MODEL_PATH = "models/saved_bert_model_v2"

@st.cache_resource
def load_classifier():
    """Load the fine-tuned BERT classifier once and cache it across Streamlit reruns."""
    tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
    model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
    model.eval()  # inference mode: disables dropout
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use GPU if available
    model.to(device)
    return tokenizer, model, device

tokenizer, model, device = load_classifier()

# Classification reference sheets (French and Arabic/Darija).
data_cls_fr = pd.read_excel("./data/Classification dataset - Q&A.xlsx", sheet_name="Fr")
data_cls_ar = pd.read_excel("./data/Classification dataset - Q&A.xlsx", sheet_name="Ar")
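
# The label names come from the fine-tuned checkpoint's config; illustrative shape:
#   model.config.id2label  ->  {0: "<label 0>", 1: "<label 1>", ...}
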
def predict_class(answer, max_length=500):
    """
    Predict the class (as a string) for a given answer text.

    Args:
        answer (str): The answer text to classify.
        max_length (int): Maximum token length for the tokenizer.

    Returns:
        str: The predicted class label as text.
    """
    # Tokenize the input text into a fixed-length tensor batch of size 1.
    inputs = tokenizer(
        answer,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )
    # Move inputs to the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Inference without gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
    # Extract logits and map the argmax id to its label.
    logits = outputs.logits
    predicted_class_id = torch.argmax(logits, dim=1).item()
    return model.config.id2label[predicted_class_id]
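
# Example usage (hypothetical query, for illustration only):
#   predict_class("Comment consulter mon solde ?")  ->  a label from id2label
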
def main():
    st.title("Assistant de chat INWI")
    st.markdown(
        "Bienvenue sur votre assistant IA INWI. "
        "Posez votre question ici et obtenez une réponse instantanée !"
    )
    # Initialize session state variables.
    initialize_session_state()

    # Sidebar controls: allow resetting the conversation.
    st.sidebar.header("Conversation Controls")
    if st.sidebar.button("Reset Conversation"):
        st.session_state.chat_history = []
        st.session_state.pipeline = None
        st.session_state.pipeline_language = None
        st.rerun()

    # Text input for the user message.
    user_input = st.text_input("Your message:", key="input_text")
    if st.button("Send") and user_input:
        # Detect the language of the user input.
        try:
            lang, _ = st.session_state.language_detector.detect_language(user_input)
            st.write(f"**Detected language:** {lang}")
        except Exception as e:
            st.error(f"Language detection error: {e}")
            lang = "fr"  # fall back to French if detection fails

        # (Re)initialize the pipeline if it has not been created yet or if the language changed.
        if st.session_state.pipeline is None or st.session_state.pipeline_language != lang:
            st.info(f"Initializing pipeline for language: {lang}")
            try:
                st.session_state.pipeline = get_pipeline_for_language(lang)
                st.session_state.pipeline_language = lang
            except Exception as e:
                st.error(f"Error initializing pipeline: {e}")
                return
        # Generate the response using the RAG pipeline.
        try:
            response = st.session_state.pipeline.generate_response(
                query=user_input,
                chat_history=st.session_state.chat_history
            )

            # Classify to recover the matching reference entry.
            # NOTE: this classifies the user query; pass `response` instead to
            # classify the generated answer.
            predicted_label = predict_class(user_input)
            st.write(f"**Predicted class for the answer**: {predicted_label}")

            # Pick the sheet matching the detected language (assumes both
            # sheets share the "Réponse"/"Référence" columns).
            data_cls = data_cls_ar if lang == "darija" else data_cls_fr
            matches = data_cls.loc[data_cls["Réponse"] == predicted_label, "Référence"]
            if not matches.empty:
                st.write(f"**Predicted reference for the answer**: {matches.iloc[0]}")
            else:
                st.warning("No reference found for the predicted class.")
        except Exception as e:
            response = f"Error generating response: {e}"

        # Update the conversation history.
        st.session_state.chat_history.append(HumanMessage(content=user_input))
        st.session_state.chat_history.append(AIMessage(content=response))
    # Display the conversation history on every rerun.
    st.markdown("## Conversation")
    for msg in st.session_state.chat_history:
        if isinstance(msg, HumanMessage):
            st.markdown(f"**User:** {msg.content}")
        elif isinstance(msg, AIMessage):
            st.markdown(f"**Assistant:** {msg.content}")

if __name__ == '__main__':
    main()
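
# Run locally with the standard Streamlit CLI:
#   streamlit run app.py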