# Viral-808 / app.py
# Last commit: 9aaf00a "Update app.py" by Fred808 (verified)
# api/main.py
import inspect
import logging
import re
import time
from collections import Counter
from functools import lru_cache
from io import BytesIO
from typing import List, Dict

import joblib
import pandas as pd
import pytesseract
import requests
import torch
from fastapi import FastAPI, HTTPException, Depends
from fastapi import status
from PIL import Image
from pydantic import BaseModel, constr
from sqlalchemy.orm import Session
from textblob import TextBlob
from torchvision import transforms
from transformers import ResNetForImageClassification

from utils.database import init_db, save_to_db, fetch_posts_from_db, get_db
from utils.instaloader_utils import fetch_user_posts, fetch_competitors_posts
# Set up logging: timestamped, INFO-and-above records on the root logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# ASGI application instance that the route handlers register on.
app = FastAPI()

# Initialise the database layer (utils.database.init_db) once, at import time.
init_db()

# Load the three pre-trained models once at import time, in this order.
# joblib files are pickles: loading one can execute arbitrary code, so only
# trusted model files may be placed under models/.
viral_model, engagement_model, promotion_model = (
    joblib.load(model_path)
    for model_path in (
        "models/viral_potential_model.pkl",
        "models/engagement_rate_model.pkl",
        "models/promotion_strategy_model.pkl",
    )
)
class UserRequest(BaseModel):
    """Request body naming the Instagram account to operate on."""

    # Instagram username whose posts are fetched (/fetch-posts) or analyzed (/analyze).
    username: str
class AnalyzePostRequest(BaseModel):
    """Request body for /analyze-post: one post to score."""

    # Caption text; its length and TextBlob sentiment are used as model features.
    caption: str
    # All hashtags as a single comma-separated string (split on "," by /analyze-post).
    hashtags: str
    # URL of the post image, downloaded by the server.
    # NOTE(review): client-controlled URL fetched server-side (SSRF risk) — consider validating.
    image_url: str
# Module-level logger. Logging is configured once by the basicConfig call at
# the top of the file; calling basicConfig again here would be a no-op because
# the root logger already has a handler by then.
logger = logging.getLogger(__name__)

# Rate limiting variables for /fetch-posts. They are process-wide globals, so
# the limit is shared by every client rather than applied per user.
RATE_LIMIT_DELAY = 5  # Minimum delay in seconds between accepted calls
LAST_REQUEST_TIME: float = 0  # Clock reading of the last accepted call (0 = none yet)
@app.post("/fetch-posts")
async def fetch_posts(user: UserRequest):
    """
    Fetch posts from a given Instagram profile (public data only).

    The fetched posts are persisted with ``save_to_db`` and echoed back.

    Rate limiting is global: at most one call is accepted every
    ``RATE_LIMIT_DELAY`` seconds across all clients, because the time of the
    last accepted call is a single module-level variable (not tracked per
    user or per IP).

    Args:
        user: request body carrying the Instagram ``username``.

    Returns:
        ``{"status": "success", "data": <posts>, "message": <summary>}``.

    Raises:
        HTTPException: 429 if called again within ``RATE_LIMIT_DELAY``
            seconds; 404 if no posts are found; 500 if saving to the database
            fails or any unexpected error occurs.
    """
    global LAST_REQUEST_TIME
    username = user.username
    logger.info(f"Fetching posts for user: {username}")

    # Rate limiting. time.monotonic() is unaffected by wall-clock changes
    # (NTP steps, manual adjustments), which could otherwise lock the
    # endpoint out or let calls through early.
    current_time = time.monotonic()
    if current_time - LAST_REQUEST_TIME < RATE_LIMIT_DELAY:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Please wait a few seconds before making another request.",
        )
    LAST_REQUEST_TIME = current_time

    try:
        # fetch_user_posts lives in utils.instaloader_utils and is not visible
        # here, so accept either a plain result or an awaitable. Resolving it
        # first means the emptiness check inspects the real posts (a coroutine
        # object is always truthy) and a plain list is never awaited (which
        # would raise TypeError).
        # NOTE(review): if fetch_user_posts is synchronous it blocks the event
        # loop for the whole fetch; consider running it in a thread pool.
        all_posts = fetch_user_posts(username)
        if inspect.isawaitable(all_posts):
            all_posts = await all_posts
        if not all_posts:
            logger.warning(f"No posts found for user: {username}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No posts found for the user.",
            )

        # NOTE(review): only the user's own posts are fetched and saved;
        # competitor posts (fetch_competitors_posts) are not included here.

        # Save data to the database; a falsy return signals failure.
        if not save_to_db(all_posts):
            logger.error("Failed to save data to the database.")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to save data to the database.",
            )

        return {
            "status": "success",
            "data": all_posts,
            "message": f"Successfully fetched {len(all_posts)} posts.",
        }
    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        logger.exception(f"Unexpected error fetching posts: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred. Please try again later.",
        ) from e
@app.post("/analyze")
async def analyze(user: UserRequest, db: Session = Depends(get_db)):
    """
    Analyze the stored posts of a user.

    Loads the user's posts with ``fetch_posts_from_db`` and computes
    per-post viral potential, the most frequent hashtags and mean
    likes/comments.

    NOTE(review): ``db`` is injected but unused; ``fetch_posts_from_db`` is
    called with the username only. Competitor data is not analyzed here.

    Args:
        user: request body carrying the ``username`` to analyze.

    Returns:
        ``{"status": "success", "results": {...}}`` with ``viral_potential``,
        ``top_hashtags`` and ``engagement_stats``.

    Raises:
        HTTPException: 404 if the user has no stored posts; 500 (detail is
            the error text) for any other failure.
    """
    username = user.username
    logger.info(f"Analyzing data for user: {username}")
    try:
        user_posts = fetch_posts_from_db(username)
        if not user_posts:
            raise HTTPException(status_code=404, detail="No posts found for the user.")

        # Non-zero: the empty case was rejected just above.
        post_count = len(user_posts)
        analysis_results = {
            "viral_potential": predict_viral_potential(user_posts),
            "top_hashtags": recommend_hashtags(user_posts),
            "engagement_stats": {
                "mean_likes": sum(post['likes'] for post in user_posts) / post_count,
                "mean_comments": sum(post['comments'] for post in user_posts) / post_count,
            },
        }
        return {"status": "success", "results": analysis_results}
    except HTTPException:
        # Without this, the 404 above would be caught below and turned into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error analyzing data: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Seconds to wait for the image host before /analyze-post gives up.
IMAGE_FETCH_TIMEOUT = 10


def _to_builtin(value):
    """Convert a numpy scalar to the matching Python built-in; pass other values through.

    Indexing the arrays returned by the models' predict/predict_proba yields
    numpy scalars (presumably sklearn-style models — confirm). Integer ones
    such as int64 are not JSON-serializable by FastAPI, so normalize them.
    """
    return value.item() if hasattr(value, "item") else value


@app.post("/analyze-post")
async def analyze_post(post: AnalyzePostRequest, db: Session = Depends(get_db)):
    """
    Analyze a single post (caption, hashtags, and image).

    Downloads ``post.image_url``, extracts text from it (OCR), runs
    ``analyze_image`` on it, scores caption/hashtag features with the three
    loaded models, stores the result via ``save_to_db`` and returns the
    predictions.

    NOTE(review): ``db`` is injected but unused. The blocking download, OCR
    and model inference run directly inside this ``async`` handler and
    therefore block the event loop while they execute.

    Returns:
        dict with ``extracted_text``, ``image_analysis``, ``viral_score``,
        ``engagement_rate`` and ``promote``.

    Raises:
        HTTPException: 500 (detail is the error text) on any failure.
    """
    try:
        # NOTE(review): image_url is client-supplied and fetched by the server
        # (SSRF risk); consider restricting schemes/hosts.
        # The timeout keeps an unresponsive host from hanging the request.
        response = requests.get(post.image_url, timeout=IMAGE_FETCH_TIMEOUT)
        response.raise_for_status()

        with Image.open(BytesIO(response.content)) as image:
            # Order matters: extract_text_from_image shrinks the image in place
            # (via resize_image), so analyze_image receives the shrunk image.
            extracted_text = extract_text_from_image(image)
            image_analysis = analyze_image(image)

        # Model features; key order fixes the DataFrame column order.
        # NOTE(review): "".split(",") is [""], so an empty hashtag string
        # counts as 1. Left unchanged so features match what the models
        # presumably saw in training.
        features = {
            'caption_length': len(post.caption),
            'hashtag_count': len(post.hashtags.split(",")),
            'sentiment': TextBlob(post.caption).sentiment.polarity,
        }
        features_df = pd.DataFrame([features])

        # Make predictions (converted to built-ins for JSON / DB storage).
        viral_score = _to_builtin(viral_model.predict_proba(features_df)[0][1])
        engagement_rate = _to_builtin(engagement_model.predict(features_df)[0])
        promote = bool(promotion_model.predict(features_df)[0])

        post_data = {
            "caption": post.caption,
            "hashtags": post.hashtags,
            "image_url": post.image_url,
            "engagement_rate": engagement_rate,
            "viral_score": viral_score,
            "promote": promote,
        }
        # save_to_db signals failure with a falsy return (as /fetch-posts
        # relies on); log it instead of discarding an otherwise complete result.
        if not save_to_db([post_data]):
            logger.warning("Failed to save analyzed post to the database.")

        return {
            "extracted_text": extracted_text,
            "image_analysis": image_analysis,
            "viral_score": viral_score,
            "engagement_rate": engagement_rate,
            "promote": promote,
        }
    except Exception as e:
        logger.exception(f"Error analyzing post: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Image processing functions
def resize_image(image, max_size=(800, 600)):
    """Shrink *image* in place so it fits within *max_size*.

    Delegates to the image's ``thumbnail`` method. For a PIL image this keeps
    the aspect ratio and never enlarges. The caller's object is modified.

    Args:
        image: object exposing ``thumbnail(size)`` (a PIL image in this file).
        max_size: (width, height) bounding box; defaults to 800x600.

    Returns:
        The same object that was passed in.
    """
    bounded = image
    bounded.thumbnail(max_size)
    return bounded
def extract_text_from_image(image):
    """Run Tesseract OCR on *image* and return the recognised text.

    The image first goes through ``resize_image``, which shrinks it in place,
    so the caller's image object is modified as a side effect.

    Returns:
        The OCR text, or ``""`` if any step fails (the error is logged).
    """
    try:
        shrunk = resize_image(image)
        return pytesseract.image_to_string(shrunk)
    except Exception as exc:
        logging.error(f"Error extracting text from image: {exc}")
        return ""
@lru_cache(maxsize=1)
def _get_image_preprocess():
    """Build the ResNet-50 input pipeline once and reuse it for every call."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean / std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


@lru_cache(maxsize=1)
def _get_resnet_model():
    """Load the pretrained ResNet-50 classifier once, in eval mode.

    Loading the weights is expensive, so the model is cached for the life of
    the process instead of being reloaded per request. lru_cache does not
    store exceptions, so a failed load is retried on the next call.
    """
    model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
    model.eval()
    return model


def analyze_image(image):
    """Analyze image content using a pre-trained ResNet-50 model.

    Args:
        image: a PIL image in any mode; it is converted to RGB (a copy) first,
            so the caller's image is not modified.

    Returns:
        The raw classification logits as a nested list with a single row
        (batch of one image), or ``None`` if anything fails (the error is
        logged with traceback).
    """
    try:
        # ToTensor keeps the source channel count, but Normalize expects 3
        # channels; RGBA, grayscale or palette images would otherwise fail.
        rgb_image = image.convert("RGB")
        image_tensor = _get_image_preprocess()(rgb_image).unsqueeze(0)
        model = _get_resnet_model()
        with torch.no_grad():
            output = model(image_tensor)
        return output.logits.tolist()
    except Exception as e:
        logging.exception(f"Error analyzing image: {e}")
        return None
# Helper functions
def predict_viral_potential(posts: List[Dict]) -> List[Dict]:
    """
    Predict viral potential for posts.

    Placeholder implementation: every post gets the same fixed score, 0.8.
    Only each post's ``"caption"`` is read.

    Args:
        posts: post dicts; each must contain a ``"caption"`` key.

    Returns:
        One ``{"caption": ..., "viral_score": 0.8}`` dict per post, in input order.
    """
    # NOTE(review): placeholder — the loaded viral_model is not used here.
    fixed_score = 0.8
    return [dict(caption=entry["caption"], viral_score=fixed_score) for entry in posts]
def recommend_hashtags(posts: List[Dict], top_n: int = 10) -> List[str]:
    """
    Recommend trending hashtags.

    Counts hashtag occurrences across *posts* and returns the most frequent
    ones, most common first. Ties keep first-seen order.

    Args:
        posts: post dicts with a ``'hashtags'`` entry. The entry is either a
            list of tags or one comma-separated string (the form
            ``/analyze-post`` stores). ``None``/empty entries and blank tags
            are skipped.
        top_n: maximum number of hashtags to return (default 10).

    Returns:
        Up to ``top_n`` hashtags.
    """
    hashtag_counts = Counter()
    for post in posts:
        tags = post['hashtags'] or []
        # A plain string would otherwise be iterated character by character.
        if isinstance(tags, str):
            tags = [tag.strip() for tag in tags.split(",")]
        hashtag_counts.update(tag for tag in tags if tag)
    return [hashtag for hashtag, _ in hashtag_counts.most_common(top_n)]
# Run the API
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")