# api/main.py
"""FastAPI service that fetches Instagram posts and scores posts with pre-trained models."""
import inspect
import logging
import re
import time
from collections import Counter
from functools import lru_cache
from io import BytesIO
from typing import Dict, List

import joblib
import pandas as pd
import pytesseract
import requests
import torch
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from pydantic import BaseModel, constr
from sqlalchemy.orm import Session
from textblob import TextBlob
from torchvision import transforms
from transformers import ResNetForImageClassification

from utils.database import init_db, save_to_db, fetch_posts_from_db, get_db
from utils.instaloader_utils import fetch_user_posts, fetch_competitors_posts

# Set up logging (configured once; a second basicConfig call would be a no-op anyway).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Initialize database
init_db()

# Load models.
# NOTE: joblib.load unpickles arbitrary objects; only ever point these at trusted,
# locally produced model files.
viral_model = joblib.load("models/viral_potential_model.pkl")
engagement_model = joblib.load("models/engagement_rate_model.pkl")
promotion_model = joblib.load("models/promotion_strategy_model.pkl")


class UserRequest(BaseModel):
    username: str


class AnalyzePostRequest(BaseModel):
    caption: str
    hashtags: str
    image_url: str


# Rate limiting variables
RATE_LIMIT_DELAY = 5  # Delay in seconds between API calls
LAST_REQUEST_TIME = 0

# Upper bound (seconds) for downloading a post image, so a slow host cannot hang a worker.
IMAGE_DOWNLOAD_TIMEOUT = 10


@app.post("/fetch-posts")
async def fetch_posts(user: UserRequest):
    """
    Fetch posts from a given Instagram profile (public data only) and save them.

    Raises:
        HTTPException: 429 if called less than RATE_LIMIT_DELAY seconds after the
            previous call, 404 if no posts were returned, 500 if saving fails or
            any unexpected error occurs.
    """
    global LAST_REQUEST_TIME
    username = user.username
    logger.info("Fetching posts for user: %s", username)

    # Rate limiting. NOTE: this is one process-wide timestamp, so it throttles all
    # clients together rather than per user.
    current_time = time.time()
    if current_time - LAST_REQUEST_TIME < RATE_LIMIT_DELAY:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Please wait a few seconds before making another request.",
        )
    LAST_REQUEST_TIME = current_time

    try:
        # Fetch user's posts. fetch_user_posts may be sync (returns the posts) or
        # async (returns a coroutine): resolve it BEFORE the emptiness check, because a
        # coroutine object is always truthy and awaiting a plain list raises TypeError.
        user_posts = fetch_user_posts(username)
        if inspect.isawaitable(user_posts):
            user_posts = await user_posts

        if not user_posts:
            logger.warning("No posts found for user: %s", username)
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No posts found for the user.",
            )

        # Competitor posts are not fetched yet; only the user's own posts are stored.
        all_posts = user_posts

        # Save data to the database
        if not save_to_db(all_posts):
            logger.error("Failed to save data to the database.")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to save data to the database.",
            )

        # Return success response
        return {
            "status": "success",
            "data": all_posts,
            "message": f"Successfully fetched {len(all_posts)} posts.",
        }
    except HTTPException:
        # Re-raise HTTPException to return specific error messages
        raise
    except Exception as e:
        logger.exception("Unexpected error fetching posts: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred. Please try again later.",
        )


@app.post("/analyze")
async def analyze(user: UserRequest, db: Session = Depends(get_db)):
    """
    Analyze the stored posts of a user.

    Returns viral-potential scores, the most frequent hashtags and mean likes/comments.

    Raises:
        HTTPException: 404 if no posts are stored for the user, 500 on any other error.
    """
    username = user.username
    logger.info("Analyzing data for user: %s", username)
    try:
        # Fetch data from the database
        user_posts = fetch_posts_from_db(username)
        if not user_posts:
            raise HTTPException(status_code=404, detail="No posts found for the user.")

        # Perform analysis (e.g., viral potential, engagement rate, etc.)
        # user_posts is non-empty here, so the means below cannot divide by zero.
        post_count = len(user_posts)
        analysis_results = {
            "viral_potential": predict_viral_potential(user_posts),
            "top_hashtags": recommend_hashtags(user_posts),
            "engagement_stats": {
                "mean_likes": sum(post['likes'] for post in user_posts) / post_count,
                "mean_comments": sum(post['comments'] for post in user_posts) / post_count,
            },
        }
        return {"status": "success", "results": analysis_results}
    except HTTPException:
        # Keep intentional status codes (e.g. the 404 above) instead of turning them into 500s.
        raise
    except Exception as e:
        logger.exception("Error analyzing data: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/analyze-post")
async def analyze_post(post: AnalyzePostRequest, db: Session = Depends(get_db)):
    """
    Analyze a single post (caption, hashtags, and image).

    Returns the OCR text, raw ResNet logits, viral score, predicted engagement rate
    and a promote/no-promote decision.

    Raises:
        HTTPException: 400 if the image cannot be downloaded or decoded,
            500 on any other error.
    """
    try:
        # Download, OCR, CNN inference and model predictions are all blocking; run them
        # in a worker thread so the event loop keeps serving other requests meanwhile.
        return await run_in_threadpool(_analyze_post_sync, post)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error analyzing post: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


def _download_image(image_url: str) -> Image.Image:
    """Download *image_url* and decode it with PIL; raise HTTPException(400) on failure."""
    # SECURITY NOTE(review): image_url is client-supplied and fetched server-side
    # (SSRF risk). Consider restricting allowed schemes/hosts.
    try:
        response = requests.get(image_url, timeout=IMAGE_DOWNLOAD_TIMEOUT)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        image.load()  # force decoding now so corrupt data fails here, not later
    except (requests.RequestException, OSError) as e:
        # PIL's UnidentifiedImageError and truncated-image errors are OSError subclasses.
        logger.warning("Could not load image from %s: %s", image_url, e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Could not download or decode the image at image_url.",
        ) from e
    return image


def _post_features(post: AnalyzePostRequest) -> Dict[str, float]:
    """Build the single feature row fed to the viral/engagement/promotion models."""
    # Count only non-empty comma-separated tags; "".split(",") alone would count 1.
    hashtag_count = sum(1 for tag in post.hashtags.split(",") if tag.strip())
    # Key order is kept as-is because it becomes the DataFrame column order.
    return {
        'caption_length': len(post.caption),
        'hashtag_count': hashtag_count,
        'sentiment': TextBlob(post.caption).sentiment.polarity,
    }


def _analyze_post_sync(post: AnalyzePostRequest) -> Dict:
    """Blocking implementation of /analyze-post; meant to run in a worker thread."""
    # Download and analyze the image
    image = _download_image(post.image_url)

    # Extract text from the image
    extracted_text = extract_text_from_image(image)

    # Analyze the image content
    image_analysis = analyze_image(image)

    # Preprocess input for models
    features_df = pd.DataFrame([_post_features(post)])

    # Make predictions. Cast numpy scalars to plain Python types for JSON and DB storage.
    viral_score = float(viral_model.predict_proba(features_df)[0][1])
    engagement_rate = float(engagement_model.predict(features_df)[0])
    promote = bool(promotion_model.predict(features_df)[0])

    # Save post to database (best-effort: a failed save is logged, the analysis still returns)
    post_data = {
        "caption": post.caption,
        "hashtags": post.hashtags,
        "image_url": post.image_url,
        "engagement_rate": engagement_rate,
        "viral_score": viral_score,
        "promote": promote,
    }
    if not save_to_db([post_data]):
        logger.warning("Failed to save analyzed post to the database.")

    return {
        "extracted_text": extracted_text,
        "image_analysis": image_analysis,
        "viral_score": viral_score,
        "engagement_rate": engagement_rate,
        "promote": promote,
    }


# Image processing functions
def resize_image(image, max_size=(800, 600)):
    """Resize an image to the specified maximum size (in place, keeping aspect ratio)."""
    image.thumbnail(max_size)
    return image


def extract_text_from_image(image):
    """Extract text from an image using OCR; return "" if OCR fails."""
    try:
        # thumbnail() mutates its argument; resize a copy so the caller's image
        # (later fed to the CNN) keeps its original resolution.
        resized = resize_image(image.copy())
        return pytesseract.image_to_string(resized)
    except Exception as e:
        logger.error("Error extracting text from image: %s", e)
        return ""


# Standard ImageNet preprocessing; built once instead of on every request.
_IMAGE_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


@lru_cache(maxsize=1)
def _get_resnet_model():
    """Load the ResNet-50 classifier once and reuse it (from_pretrained is expensive)."""
    model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
    model.eval()
    return model


def analyze_image(image):
    """Analyze image content using a pre-trained model; return logits or None on failure."""
    try:
        # Normalize() expects 3 channels; RGBA / grayscale / palette images would fail.
        image_tensor = _IMAGE_PREPROCESS(image.convert("RGB")).unsqueeze(0)
        model = _get_resnet_model()
        with torch.no_grad():
            output = model(image_tensor)
        return output.logits.tolist()  # Return the logits as a list
    except Exception as e:
        logger.error("Error analyzing image: %s", e)
        return None


# Helper functions
def predict_viral_potential(posts: List[Dict]) -> List[Dict]:
    """
    Predict viral potential for posts.

    Placeholder: currently returns a constant score of 0.8 for every post.
    """
    return [{"caption": post["caption"], "viral_score": 0.8} for post in posts]


def _normalize_hashtags(raw) -> List[str]:
    """Return hashtags as a list of non-empty strings, whether stored as a list or a string."""
    if not raw:
        return []
    # A string (e.g. "a,b c" as saved by /analyze-post) would otherwise be iterated per character.
    tokens = re.split(r"[,\s]+", raw) if isinstance(raw, str) else raw
    return [str(tag).strip() for tag in tokens if str(tag).strip()]


def recommend_hashtags(posts: List[Dict], top_n: int = 10) -> List[str]:
    """
    Recommend trending hashtags: the *top_n* most frequent tags across *posts*.
    """
    hashtag_counts = Counter(
        tag for post in posts for tag in _normalize_hashtags(post.get("hashtags"))
    )
    return [hashtag for hashtag, _ in hashtag_counts.most_common(top_n)]


# Run the API
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)