import streamlit as st
from langchain_core.messages import HumanMessage, AIMessage
from src.RAG_pipeline import RAGConfig, RAGPipeline
from src.language_detector import LanguageDetector

import torch
import pandas as pd
from transformers import BertForSequenceClassification, BertTokenizer

# ========================
# Streamlit App Definition
# ========================
def initialize_session_state():
    """Initialize session state variables if not already set."""
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "pipeline" not in st.session_state:
        st.session_state.pipeline = None
    if "pipeline_language" not in st.session_state:
        st.session_state.pipeline_language = None
    if "language_detector" not in st.session_state:
        try:
            st.session_state.language_detector = LanguageDetector(config_path="./Configs/config.yaml")
        except Exception as e:
            st.error(f"Error initializing language detector: {e}")

def get_pipeline_for_language(language: str) -> RAGPipeline:
    """
    Load the RAG configuration from YAML and override its persist_dir
    based on the detected language.
    """
    # Choose the persist directory based on language.
    desired_persist_dir = "./data/chroma_db_ar" if language == "darija" else "./data/chroma_db_fr"
    config = RAGConfig.load_from_yaml("./Configs/config.yaml")
    config.persist_dir = desired_persist_dir  # override according to detected language
    return RAGPipeline(config)
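
# Optional sketch (not wired into main() below): rebuilding the pipeline re-opens
# the Chroma store on every language switch. Assuming RAGPipeline is safe to reuse
# across Streamlit reruns, st.cache_resource can keep one instance per language:
@st.cache_resource
def get_cached_pipeline(language: str) -> RAGPipeline:
    """Cached variant of get_pipeline_for_language (one instance per language)."""
    return get_pipeline_for_language(language)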

# ========================
# BERT Answer Classifier
# ========================
# Load BERT model and tokenizer
MODEL_PATH = "models/saved_bert_model_v2"
tokenizer = BertTokenizer.from_pretrained(MODEL_PATH)
model = BertForSequenceClassification.from_pretrained(MODEL_PATH)
model.eval()  # set model to eval mode (disables dropout)

# Classification reference data (one sheet per language)
data_cls_fr = pd.read_excel("./data/Classification dataset - Q&A.xlsx", sheet_name="Fr")
data_cls_ar = pd.read_excel("./data/Classification dataset - Q&A.xlsx", sheet_name="Ar")

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
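
# Optional sketch: Streamlit re-executes this script on every interaction, so the
# Excel sheets above are re-read on each rerun. st.cache_data can cache the parsed
# DataFrames (the path and sheet names here match the loads above):
@st.cache_data
def load_classification_sheet(sheet_name: str) -> pd.DataFrame:
    """Cached read of one sheet of the classification workbook."""
    return pd.read_excel("./data/Classification dataset - Q&A.xlsx", sheet_name=sheet_name)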

def predict_class(answer, max_length=500):
    """
    Predict the class (as a string) for a given answer text.

    Args:
        answer (str): The answer text to classify.
        max_length (int): Maximum token length for the tokenizer.

    Returns:
        str: The predicted class label as text.
    """
    # Tokenize
    inputs = tokenizer(
        answer,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )
    # Move inputs to the same device as the model
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Inference
    with torch.no_grad():
        outputs = model(**inputs)
    # Extract logits and map the argmax id back to its label string
    logits = outputs.logits
    predicted_class_id = torch.argmax(logits, dim=1).item()
    return model.config.id2label[predicted_class_id]
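
# Minimal usage sketch for predict_class (the query below is illustrative;
# the actual labels depend on the fine-tuned model's id2label mapping):
#
#   label = predict_class("Comment consulter mon solde ?")
#   # label is one of the strings in model.config.id2label.values()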

def main():
    st.title("Assistant de chat INWI")
    st.markdown(
        """
        Bienvenue sur votre assistant IA INWI.
        Posez votre question ici et obtenez une réponse instantanée !
        """
    )

    # Initialize session state variables.
    initialize_session_state()

    # Sidebar controls: allow resetting the conversation.
    st.sidebar.header("Conversation Controls")
    if st.sidebar.button("Reset Conversation"):
        st.session_state.chat_history = []
        st.session_state.pipeline = None
        st.session_state.pipeline_language = None
        st.rerun()
    # Text input for the user message.
    user_input = st.text_input("Your message:", key="input_text")

    if st.button("Send") and user_input:
        # Detect the language of the user input.
        try:
            lang, _ = st.session_state.language_detector.detect_language(user_input)
            st.write(f"**Detected language:** {lang}")
        except Exception as e:
            st.error(f"Language detection error: {e}")
            lang = "fr"  # fall back to French if detection fails

        # (Re)initialize the pipeline if it has not been created yet or if the language changed.
        if st.session_state.pipeline is None or st.session_state.pipeline_language != lang:
            st.info(f"Initializing pipeline for language: {lang}")
            try:
                st.session_state.pipeline = get_pipeline_for_language(lang)
                st.session_state.pipeline_language = lang
            except Exception as e:
                st.error(f"Error initializing pipeline: {e}")
                return
        # Generate the response using the RAG pipeline.
        try:
            response = st.session_state.pipeline.generate_response(
                query=user_input,
                chat_history=st.session_state.chat_history
            )

            # Classify the RAG pipeline's response and look up its reference id
            # in the sheet matching the detected language.
            predicted_label = predict_class(response)
            st.write(f"**Predicted class for the answer**: {predicted_label}")
            data_cls = data_cls_ar if lang == "darija" else data_cls_fr
            matches = data_cls.loc[data_cls["Réponse"] == predicted_label, "Référence"]
            if not matches.empty:
                st.write(f"**Predicted reference for the answer**: {matches.iloc[0]}")
        except Exception as e:
            response = f"Error generating response: {e}"

        # Update the conversation history.
        st.session_state.chat_history.append(HumanMessage(content=user_input))
        st.session_state.chat_history.append(AIMessage(content=response))

    # Display the conversation history.
    st.markdown("## Conversation")
    for msg in st.session_state.chat_history:
        if isinstance(msg, HumanMessage):
            st.markdown(f"**User:** {msg.content}")
        elif isinstance(msg, AIMessage):
            st.markdown(f"**Assistant:** {msg.content}")


if __name__ == '__main__':
    main()
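
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py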