# YT-Trainer / app.py
import re
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import (
pipeline,
AutoModelForSequenceClassification,
AutoTokenizer,
AutoModelForCausalLM,
T5Tokenizer,
T5ForConditionalGeneration,
)
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
import faiss
import numpy as np
from datasets import load_dataset, Features, Value
# Initialize FastAPI app
app = FastAPI()
# Preprocessing function
def preprocess_text(text):
"""
Cleans and tokenizes text.
"""
text = re.sub(r"http\S+|www\S+|https\S+", "", text, flags=re.MULTILINE) # Remove URLs
text = re.sub(r"\s+", " ", text).strip() # Remove extra spaces
text = re.sub(r"[^\w\s]", "", text) # Remove punctuation
return text.lower()
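# Note: preprocess_text is defined but not currently called by any endpoint; it can
# optionally be applied to incoming text before classification, search, or topic extraction.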
# Content Classification Model
class ContentClassifier:
def __init__(self, model_name="bert-base-uncased"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
self.pipeline = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer)
def classify(self, text):
"""
Classifies text into predefined categories.
"""
result = self.pipeline(text)
return result
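# Note: the default "bert-base-uncased" checkpoint has a randomly initialised classification
# head, so the pipeline returns generic labels such as "LABEL_0"/"LABEL_1". For meaningful
# categories, pass a fine-tuned checkpoint instead, e.g. (illustrative)
# ContentClassifier("distilbert-base-uncased-finetuned-sst-2-english").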
# Relevance Detection Model
class RelevanceDetector:
def __init__(self, model_name="bert-base-uncased"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
self.pipeline = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer)
def detect_relevance(self, text, threshold=0.5):
"""
Detects whether a text is relevant to a specific domain.
"""
result = self.pipeline(text)
return result[0]["label"] == "RELEVANT" and result[0]["score"] > threshold
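# Note: the "RELEVANT" label assumes a checkpoint fine-tuned for relevance detection.
# With the default "bert-base-uncased" model the pipeline emits "LABEL_0"/"LABEL_1",
# so detect_relevance will always return False until a fine-tuned model is supplied.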
# Topic Extraction Model using BERTopic
class TopicExtractor:
def __init__(self):
self.model = BERTopic()
def extract_topics(self, documents):
"""
Extracts topics from a list of documents.
"""
topics, probs = self.model.fit_transform(documents)
return self.model.get_topic_info()
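# Note: BERTopic fits UMAP and HDBSCAN under the hood, so fit_transform typically needs
# a corpus of at least a few dozen documents; calling it on a single text will usually fail.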
# Summarization Model
class Summarizer:
def __init__(self, model_name="t5-small"):
self.tokenizer = T5Tokenizer.from_pretrained(model_name)
self.model = T5ForConditionalGeneration.from_pretrained(model_name)
def summarize(self, text, max_length=100):
"""
Summarizes a given text.
"""
inputs = self.tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=512, truncation=True)
summary_ids = self.model.generate(inputs, max_length=max_length, min_length=25, length_penalty=2.0, num_beams=4)
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return summary
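# Note: T5Tokenizer is SentencePiece-based and requires the `sentencepiece` package to be
# installed; inputs longer than 512 tokens are truncated before summarization.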
# Search and Recommendation Model using FAISS
class SearchEngine:
def __init__(self, embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
self.model = SentenceTransformer(embedding_model)
self.index = None
self.documents = []
def build_index(self, docs):
"""
Builds a FAISS index for document retrieval.
"""
self.documents = docs
embeddings = self.model.encode(docs, convert_to_tensor=True, show_progress_bar=True)
self.index = faiss.IndexFlatL2(embeddings.shape[1])
self.index.add(embeddings.cpu().detach().numpy())
def search(self, query, top_k=5):
"""
Searches the index for the top_k most relevant documents.
"""
query_embedding = self.model.encode(query, convert_to_tensor=True)
distances, indices = self.index.search(query_embedding.cpu().detach().numpy().reshape(1, -1), top_k)
# Convert NumPy data types to native Python types
results = []
        for rank, i in enumerate(indices[0]):
            if i == -1:  # FAISS pads with -1 when the index holds fewer than top_k vectors
                continue
            document = self.documents[i]
            distance = float(distances[0][rank])  # distances align with result rank, not document id
results.append({"document": document, "distance": distance})
return results
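# Note: IndexFlatL2 performs exact L2 (Euclidean) search, so smaller distances mean closer
# matches. A common alternative is cosine similarity via normalized embeddings and
# faiss.IndexFlatIP, which often ranks short paraphrased queries better.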
class Chatbot:
def __init__(self, model_name="EleutherAI/gpt-neo-125M"):
"""
Initializes the chatbot with GPT-Neo.
"""
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)
# Set pad_token to eos_token if not already defined
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def generate_response(self, prompt, max_length=100):
"""
Generates a response to a user query using GPT-Neo.
"""
# Tokenize the input prompt
inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
# Generate the response
outputs = self.model.generate(
inputs.input_ids,
attention_mask=inputs.attention_mask, # Pass the attention mask
max_length=max_length,
num_return_sequences=1,
pad_token_id=self.tokenizer.pad_token_id, # Use the defined pad_token_id
)
# Decode the generated response
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
return response
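    # Note: generate() runs greedy decoding by default here; passing do_sample=True with
    # temperature or top_p would yield more varied replies, at the cost of determinism.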
def handle_request(self, prompt):
"""
Handles user requests by determining the intent and delegating to the appropriate function.
"""
# Check if the user wants to search for something
if "search" in prompt.lower():
query = prompt.lower().replace("search", "").strip()
results = search_engine.search(query)
return {"type": "search", "results": results}
# Check if the user wants a summary
elif "summarize" in prompt.lower() or "summary" in prompt.lower():
text = prompt.lower().replace("summarize", "").replace("summary", "").strip()
summary = summarizer.summarize(text)
return {"type": "summary", "summary": summary}
# Check if the user wants to extract topics
elif "topics" in prompt.lower() or "topic" in prompt.lower():
text = prompt.lower().replace("topics", "").replace("topic", "").strip()
topics = topic_extractor.extract_topics([text])
return {"type": "topics", "topics": topics.to_dict()}
# Default to generating a conversational response
else:
response = self.generate_response(prompt)
return {"type": "chat", "response": response}
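# Note: intent routing in handle_request is a simple keyword heuristic ("search",
# "summarize"/"summary", "topic(s)"); anything else falls through to free-form GPT-Neo generation.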
# Initialize models
classifier = ContentClassifier()
relevance_detector = RelevanceDetector()
summarizer = Summarizer()
search_engine = SearchEngine()
topic_extractor = TopicExtractor()
chatbot = Chatbot()
# Initialize the search engine with a sample dataset
documents = [
"This video explains Instagram growth hacks.",
"Learn how to use hashtags effectively on Instagram.",
"Collaborations are key to growing your Instagram audience."
]
search_engine.build_index(documents)
# Define the schema
features = Features({
"video_id": Value("string"),
"video_link": Value("string"),
"title": Value("string"),
"text": Value("string"),
"channel": Value("string"),
"channel_id": Value("string"),
"date": Value("string"),
"license": Value("string"),
"original_language": Value("string"),
"source_language": Value("string"),
"transcription_language": Value("string"),
"word_count": Value("int64"),
"character_count": Value("int64"),
})
# Load the dataset from Hugging Face Hub
try:
dataset = load_dataset(
"PleIAs/YouTube-Commons",
features=features,
streaming=True,
)
# Process the dataset
for example in dataset["train"]:
print(example) # Process each example
break # Stop after the first example for demonstration
except Exception as e:
print(f"Error loading dataset: {e}")
# Pydantic models for request validation
class TextRequest(BaseModel):
text: str
class QueryRequest(BaseModel):
query: str
class PromptRequest(BaseModel):
prompt: str
# API Endpoints
@app.post("/classify")
async def classify(request: TextRequest):
text = request.text
if not text:
raise HTTPException(status_code=400, detail="No text provided")
result = classifier.classify(text)
return {"result": result}
@app.post("/relevance")
async def relevance(request: TextRequest):
text = request.text
if not text:
raise HTTPException(status_code=400, detail="No text provided")
relevant = relevance_detector.detect_relevance(text)
return {"relevant": relevant}
@app.post("/summarize")
async def summarize(request: TextRequest):
text = request.text
if not text:
raise HTTPException(status_code=400, detail="No text provided")
summary = summarizer.summarize(text)
return {"summary": summary}
@app.post("/search")
async def search(request: QueryRequest):
query = request.query
if not query:
raise HTTPException(status_code=400, detail="No query provided")
try:
results = search_engine.search(query)
return {"results": results}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/topics")
async def topics(request: TextRequest):
text = request.text
if not text:
raise HTTPException(status_code=400, detail="No text provided")
result = topic_extractor.extract_topics([text])
return {"topics": result.to_dict()}
@app.post("/chat")
async def chat(request: PromptRequest):
prompt = request.prompt
if not prompt:
raise HTTPException(status_code=400, detail="No prompt provided")
# Handle the request using the chatbot's handle_request method
result = chatbot.handle_request(prompt)
# Return the appropriate response based on the type of request
if result["type"] == "search":
return {"type": "search", "results": result["results"]}
elif result["type"] == "summary":
return {"type": "summary", "summary": result["summary"]}
elif result["type"] == "topics":
return {"type": "topics", "topics": result["topics"]}
else:
return {"type": "chat", "response": result["response"]}
# Start the FastAPI app
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
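# Example request against a locally running instance (illustrative):
#   curl -X POST http://localhost:8000/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Long transcript text to be summarized ..."}'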