import os
import re

import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
from torch.nn.functional import softmax
from transformers import AlbertForSequenceClassification, AlbertModel, AlbertTokenizer

# Paths
LEVEL_DIRS = {
    1: 'level1',
    2: 'level2',
    3: 'level3',
    4: 'level4',
    5: 'level5',
    6: 'level6',
    7: 'level7',
}
MAPPING_FILE = 'mapping.csv'
MODEL_NAME = 'albert/albert-base-v2'  # Base ALBERT model that the level>1 TaxonomyClassifier is built on
TOP_K = 3  # Candidates reported per level: the top prediction plus two runner-ups

DEFAULT_TEXT = (
    "Experience the magic of music with the Clavinova CLP-800 series. "
    "This versatile range of digital pianos is designed to delight everyone, "
    "from budding musicians to seasoned pianists. "
    "Each model combines state-of-the-art technology with the realistic touch and tone "
    "of world-renowned grand pianos, enhanced by GrandTouch keyboard action and "
    "Virtual Resonance Modeling. "
    "With seamless Bluetooth® connectivity, built-in lessons, and elegant design, "
    "the CLP-800 series offers the perfect blend of tradition and innovation. "
    "Elevate your musical journey with the warmth and sophistication of the Yamaha "
    "Clavinova, our finest series of digital pianos."
)


@st.cache_data
def _load_mapping(path):
    """Read the mapping CSV (columns 'level', 'id', 'text') once and reuse it across reruns."""
    return pd.read_csv(path)


# Load mapping
mapping_df = _load_mapping(MAPPING_FILE)


def get_label_text(level, predicted_id):
    """Return the description text for ``predicted_id`` at mapping level ``level``.

    The app calls this with ``model level - 1``, so model level 1 looks up rows
    whose 'level' column is 0. Only mapping levels 0-6 are accepted.

    Returns:
        The 'text' of the first matching row, "Description not found" when no
        row matches, or "Invalid Level" when ``level`` is outside 0-6.
    """
    if level not in range(len(LEVEL_DIRS)):
        return "Invalid Level"
    # NOTE(review): predicted_id must compare equal to the CSV's 'id' dtype for a match — confirm the label_map key type.
    row = mapping_df[(mapping_df['level'] == level) & (mapping_df['id'] == predicted_id)]
    return row['text'].iloc[0] if not row.empty else "Description not found"


def _load_label_map(level):
    """Load the ``{label_id: class_index}`` dict saved in ``LEVEL_DIRS[level]/label_map.npy``."""
    # NOTE(review): allow_pickle=True unpickles the file; only point LEVEL_DIRS at trusted training output.
    return np.load(os.path.join(LEVEL_DIRS[level], 'label_map.npy'), allow_pickle=True).item()


class TaxonomyClassifier(nn.Module):
    """ALBERT classifier for levels > 1, conditioned on a one-hot parent label.

    The pooled ALBERT output is concatenated with the parent one-hot vector and
    fed to a linear layer. Layer names and sizes must match the saved checkpoint,
    because it is loaded with the default strict ``load_state_dict``.
    """

    def __init__(self, base_model_name, num_parent_labels, num_labels):
        super().__init__()
        self.albert = AlbertModel.from_pretrained(base_model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.albert.config.hidden_size + num_parent_labels, num_labels)

    def forward(self, input_ids, attention_mask, parent_ids):
        """Return raw logits; ``parent_ids`` is a (batch, num_parent_labels) one-hot tensor."""
        outputs = self.albert(input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        combined_features = torch.cat((pooled_output, parent_ids), dim=1)
        return self.classifier(combined_features)


def _load_state_dict(path):
    """Load a state dict onto the CPU from a ``.safetensors`` file or a ``torch.save`` file."""
    if path.endswith('.safetensors'):
        # torch.load cannot parse the safetensors format. safetensors is a required
        # dependency of recent transformers releases, so it is installed alongside it.
        from safetensors.torch import load_file

        return load_file(path, device='cpu')
    return torch.load(path, map_location=torch.device('cpu'))


@st.cache_resource
def _load_level_model(level, checkpoint_path):
    """Return ``(tokenizer, model)`` for ``level`` from ``checkpoint_path``, in eval mode.

    Cached across Streamlit reruns so each checkpoint is read from disk once.
    Level 1 is a plain AlbertForSequenceClassification; deeper levels are
    TaxonomyClassifier instances sized from this level's and the parent
    level's label maps.
    """
    tokenizer = AlbertTokenizer.from_pretrained(checkpoint_path)
    if level == 1:
        model = AlbertForSequenceClassification.from_pretrained(checkpoint_path)
    else:
        num_parent_labels = len(_load_label_map(level - 1))
        num_labels = len(_load_label_map(level))
        model = TaxonomyClassifier(MODEL_NAME, num_parent_labels, num_labels)
        model.load_state_dict(_load_state_dict(os.path.join(checkpoint_path, 'model.safetensors')))
    model.eval()
    return tokenizer, model


def predict_level(level, text, parent_prediction_id=None, checkpoint_path=None, top_k=TOP_K):
    """Classify ``text`` at taxonomy ``level`` and return the most probable labels.

    Args:
        level: Taxonomy level (a key of ``LEVEL_DIRS``).
        text: Input text to classify.
        parent_prediction_id: Label ID predicted at ``level - 1``. It is ignored
            for level 1. ID 0, ``None``, or an ID absent from the parent label map
            yields an all-zero parent vector.
        checkpoint_path: Directory holding the tokenizer and model weights.
        top_k: Maximum number of candidates to return; it is clamped to the
            number of classes at this level.

    Returns:
        A list of ``(label_id, probability)`` tuples, most probable first.
    """
    tokenizer, model = _load_level_model(level, checkpoint_path)
    label_map = _load_label_map(level)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        if level == 1:
            logits = model(**inputs).logits
        else:
            parent_label_map = _load_label_map(level - 1)
            parent_one_hot = torch.zeros(len(parent_label_map))
            if parent_prediction_id != 0:
                parent_index = parent_label_map.get(parent_prediction_id)
                if parent_index is not None:
                    parent_one_hot[parent_index] = 1.0
            logits = model(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                parent_ids=parent_one_hot.unsqueeze(0),
            )

    probabilities = softmax(logits, dim=-1)[0]
    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(top_k, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)
    index_to_label = {index: label_id for label_id, index in label_map.items()}
    return [(index_to_label[idx.item()], prob.item()) for prob, idx in zip(top_probs, top_indices)]


def _checkpoint_sort_key(name):
    """Sort key: 'step' dirs by ascending step number first, then other dirs (e.g. 'model')."""
    if 'step' not in name:
        return (True, -1, name)
    digits = re.search(r'\d+', name.rsplit('step', 1)[-1])
    # A 'step' dir without a number sorts after the numbered ones instead of crashing int().
    return (False, int(digits.group()) if digits else float('inf'), name)


def _discover_checkpoints():
    """Map each level to its sorted checkpoint dirs (named 'model' or containing 'step')."""
    level_checkpoints = {}
    for level, level_dir in LEVEL_DIRS.items():
        if not os.path.isdir(level_dir):
            level_checkpoints[level] = []
            continue
        options = [
            d for d in os.listdir(level_dir)
            if os.path.isdir(os.path.join(level_dir, d)) and ('step' in d or d == 'model')
        ]
        options.sort(key=_checkpoint_sort_key)
        level_checkpoints[level] = [os.path.join(level_dir, opt) for opt in options]
    return level_checkpoints


def _run_cascade(text, selected_checkpoints, softmax_threshold):
    """Predict level by level, feeding each level's top label ID into the next level.

    Stops, and explains why in the UI, when:
    - a level's top prediction is ID 0;
    - the next consecutive level has no checkpoint; or
    - a top probability falls below ``softmax_threshold``.

    Returns ``{level: predict_level results}`` for every accepted level.
    """
    all_level_results = {}
    current_prediction_id = None
    last_level = 0
    for level in sorted(selected_checkpoints):
        if current_prediction_id == 0:
            st.info(f"Taxonomy terminated at Level {last_level} with ID 0.")
            break
        if level != last_level + 1:
            # Each level is conditioned on the previous level's labels, so a missing level breaks the chain.
            st.warning(f"Inference stopped: no checkpoint available for Level {last_level + 1}.")
            break
        checkpoint_path = selected_checkpoints[level]
        if not checkpoint_path:
            st.warning(f"Skipping Level {level} as no checkpoint is selected.")
            break
        level_results = predict_level(
            level, text, parent_prediction_id=current_prediction_id, checkpoint_path=checkpoint_path
        )
        top_label_id, top_prob = level_results[0]
        if top_prob < softmax_threshold:
            st.info(f"Inference stopped at Level {level} due to softmax probability ({top_prob:.3f}) being below the threshold.")
            break
        all_level_results[level] = level_results
        current_prediction_id = top_label_id
        last_level = level
    return all_level_results


def _results_table(all_level_results):
    """Build one table row per level: the top prediction plus ``TOP_K - 1`` runner-up columns.

    Runner-up columns are ``None`` when a level returned fewer candidates.
    """
    data = []
    for level in sorted(all_level_results):
        results = all_level_results[level]
        mapping_level = level - 1  # get_label_text is called with the model level minus one
        label_id, prob = results[0]
        row = {
            'level': level,
            'id': label_id,
            'text': get_label_text(mapping_level, label_id),
            'softmax': f"{prob:.3f}",
        }
        for rank in range(1, TOP_K):
            prefix = f'runner_up_{rank}'
            if rank < len(results):
                runner_id, runner_prob = results[rank]
                row[f'{prefix}_id'] = runner_id
                row[f'{prefix}_text'] = get_label_text(mapping_level, runner_id)
                row[f'{prefix}_softmax'] = f"{runner_prob:.3f}"
            else:
                row[f'{prefix}_id'] = None
                row[f'{prefix}_text'] = None
                row[f'{prefix}_softmax'] = None
        data.append(row)
    return data


def main():
    """Render the Streamlit UI and run the cascade when the button is pressed."""
    st.title("Taxonomy Model Inference")
    input_text = st.text_area("Enter text to classify", DEFAULT_TEXT)
    softmax_threshold = st.slider("Softmax Threshold", min_value=0.0, max_value=1.0, value=0.5, step=0.05)

    # Checkpoint Selection
    level_checkpoints = _discover_checkpoints()
    available_levels = [level for level, paths in level_checkpoints.items() if paths]
    selected_checkpoints = {
        level: st.selectbox(f"Select Level {level} Checkpoint", options=level_checkpoints[level])
        for level in available_levels
    }

    if st.button("Run Inference"):
        if input_text.strip():
            data = _results_table(_run_cascade(input_text, selected_checkpoints, softmax_threshold))
            if data:
                st.dataframe(pd.DataFrame(data))
            else:
                st.info("No predictions made or inference stopped.")
        else:
            st.warning("Please enter text for classification.")


# `streamlit run` executes this script with __name__ == "__main__".
if __name__ == "__main__":
    main()