|
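"""FastAPI service for Instagram post analytics: fetches public posts,
scores viral potential and engagement with pre-trained models, and analyzes
individual posts via OCR, sentiment, and image classification."""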
|
|
from fastapi import FastAPI, HTTPException, Depends, status

from pydantic import BaseModel
|
from typing import List, Dict

from collections import Counter
|
import logging |
|
import requests |
|
from io import BytesIO |
|
from PIL import Image |
|
import pytesseract |
|
from textblob import TextBlob |
|
import pandas as pd |
|
import joblib |
|
from sqlalchemy.orm import Session |
|
from utils.database import init_db, save_to_db, fetch_posts_from_db, get_db |
|
from utils.instaloader_utils import fetch_user_posts, fetch_competitors_posts |
|
import torch |
|
from torchvision import transforms |
|
from transformers import ResNetForImageClassification |
|
|
import time |
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|
|
|
|
|
app = FastAPI() |
|
|
|
|
|
init_db() |
|
|
|
|
|
# Pre-trained model artifacts; assumed to be scikit-learn-style estimators
# exposing predict()/predict_proba(), produced by a separate training step.
viral_model = joblib.load("models/viral_potential_model.pkl")
|
engagement_model = joblib.load("models/engagement_rate_model.pkl") |
|
promotion_model = joblib.load("models/promotion_strategy_model.pkl") |
|
|
|
class UserRequest(BaseModel): |
|
username: str |
|
|
|
class AnalyzePostRequest(BaseModel): |
|
caption: str |
|
hashtags: str |
|
image_url: str |
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
RATE_LIMIT_DELAY = 5 |
|
LAST_REQUEST_TIME = 0 |
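
# Note: this rate limiter is naive and process-local (a single global
# timestamp); it does not coordinate across workers. A shared store such as
# Redis would be needed for a real deployment.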
|
|
|
|
|
@app.post("/fetch-posts") |
|
async def fetch_posts(user: UserRequest): |
|
""" |
|
Fetch posts from a given Instagram profile (public data only). |
|
""" |
|
global LAST_REQUEST_TIME |
|
|
|
username = user.username |
|
logger.info(f"Fetching posts for user: {username}") |
|
|
|
|
|
current_time = time.time() |
|
if current_time - LAST_REQUEST_TIME < RATE_LIMIT_DELAY: |
|
raise HTTPException( |
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS, |
|
detail="Please wait a few seconds before making another request." |
|
) |
|
LAST_REQUEST_TIME = current_time |
|
|
|
    try:
        # fetch_user_posts is a coroutine, so await it before checking the result.
        all_posts = await fetch_user_posts(username)
        if not all_posts:
            logger.warning(f"No posts found for user: {username}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No posts found for the user."
            )
|
|
|
|
|
if not save_to_db(all_posts): |
|
logger.error("Failed to save data to the database.") |
|
raise HTTPException( |
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
|
detail="Failed to save data to the database." |
|
) |
|
|
|
|
|
return { |
|
"status": "success", |
|
"data": all_posts, |
|
"message": f"Successfully fetched {len(all_posts)} posts." |
|
} |
|
|
|
    except HTTPException:
        # Propagate HTTP errors as-is so FastAPI returns the intended status code.
        raise
|
except Exception as e: |
|
logger.error(f"Unexpected error fetching posts: {e}") |
|
raise HTTPException( |
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
|
detail="An unexpected error occurred. Please try again later." |
|
) |
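

# Example client call for /fetch-posts (a sketch; assumes the service runs
# locally on port 8000 and that "example_user" is a public profile):
#
#   import requests
#   resp = requests.post("http://localhost:8000/fetch-posts",
#                        json={"username": "example_user"})
#   print(resp.json())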
|
|
|
@app.post("/analyze") |
|
async def analyze(user: UserRequest, db: Session = Depends(get_db)): |
|
""" |
|
Analyze user and competitor data. |
|
""" |
|
username = user.username |
|
logging.info(f"Analyzing data for user: {username}") |
|
|
|
try: |
|
|
|
        user_posts = fetch_posts_from_db(username)
        if not user_posts:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No posts found for the user."
            )
|
|
|
|
|
analysis_results = { |
|
"viral_potential": predict_viral_potential(user_posts), |
|
"top_hashtags": recommend_hashtags(user_posts), |
|
"engagement_stats": { |
|
"mean_likes": sum(post['likes'] for post in user_posts) / len(user_posts), |
|
"mean_comments": sum(post['comments'] for post in user_posts) / len(user_posts) |
|
} |
|
} |
|
|
|
return {"status": "success", "results": analysis_results} |
|
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error analyzing data: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred while analyzing the data."
        )
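
# Example client call for /analyze (sketch; same conventions as above, and
# assumes posts for the user were previously saved via /fetch-posts):
#
#   resp = requests.post("http://localhost:8000/analyze",
#                        json={"username": "example_user"})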
|
|
|
@app.post("/analyze-post") |
|
async def analyze_post(post: AnalyzePostRequest, db: Session = Depends(get_db)): |
|
""" |
|
Analyze a single post (caption, hashtags, and image). |
|
""" |
|
try: |
|
|
|
        # Download the image; bound the request so a slow host cannot hang the worker.
        response = requests.get(post.image_url, timeout=10)
        response.raise_for_status()
        # Convert to RGB so grayscale/RGBA images work with the 3-channel pipeline.
        image = Image.open(BytesIO(response.content)).convert("RGB")
|
|
|
|
|
extracted_text = extract_text_from_image(image) |
|
|
|
|
|
image_analysis = analyze_image(image) |
|
|
|
|
|
        # Build the same feature set the models are assumed to have been
        # trained on: caption length, hashtag count, and caption sentiment.
        features = {
|
'caption_length': len(post.caption), |
|
'hashtag_count': len(post.hashtags.split(",")), |
|
'sentiment': TextBlob(post.caption).sentiment.polarity |
|
} |
|
features_df = pd.DataFrame([features]) |
|
|
|
|
|
        # Cast to built-in types so the response is JSON-serializable.
        viral_score = float(viral_model.predict_proba(features_df)[0][1])
        engagement_rate = float(engagement_model.predict(features_df)[0])
        promote = bool(promotion_model.predict(features_df)[0])
|
|
|
|
|
post_data = { |
|
"caption": post.caption, |
|
"hashtags": post.hashtags, |
|
"image_url": post.image_url, |
|
"engagement_rate": engagement_rate, |
|
"viral_score": viral_score, |
|
"promote": bool(promote) |
|
} |
|
save_to_db([post_data]) |
|
|
|
return { |
|
"extracted_text": extracted_text, |
|
"image_analysis": image_analysis, |
|
"viral_score": viral_score, |
|
"engagement_rate": engagement_rate, |
|
"promote": bool(promote) |
|
} |
|
    except requests.RequestException as e:
        logger.error(f"Error downloading image: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Could not download the image from the provided URL."
        )
    except Exception as e:
        logger.error(f"Error analyzing post: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred while analyzing the post."
        )
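
# Example client call for /analyze-post (sketch; the field values below are
# purely illustrative):
#
#   resp = requests.post("http://localhost:8000/analyze-post",
#                        json={"caption": "Sunset over the bay",
#                              "hashtags": "sunset,photography",
#                              "image_url": "https://example.com/photo.jpg"})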
|
|
|
|
|
def resize_image(image, max_size=(800, 600)): |
|
"""Resize an image to the specified maximum size.""" |
|
image.thumbnail(max_size) |
|
return image |
|
|
|
def extract_text_from_image(image): |
|
"""Extract text from an image using OCR.""" |
|
try: |
|
image = resize_image(image) |
|
text = pytesseract.image_to_string(image) |
|
return text |
|
except Exception as e: |
|
logging.error(f"Error extracting text from image: {e}") |
|
return "" |
|
|
|
# Cache the classifier at module scope so it is loaded once, not per request.
_resnet_model = None


def analyze_image(image):
    """Analyze image content with a pre-trained ResNet-50 classifier."""
    global _resnet_model
    try:
        # Standard ImageNet preprocessing for ResNet-style models.
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        image_tensor = preprocess(image).unsqueeze(0)

        if _resnet_model is None:
            _resnet_model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
            _resnet_model.eval()

        with torch.no_grad():
            output = _resnet_model(image_tensor)
        return output.logits.tolist()
|
except Exception as e: |
|
logging.error(f"Error analyzing image: {e}") |
|
return None |
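
# analyze_image returns raw logits over the model's ImageNet-1k label space.
# A caller could recover the top predicted class index like so (sketch):
#
#   logits = analyze_image(img)
#   top_class = max(range(len(logits[0])), key=lambda i: logits[0][i])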
|
|
|
|
|
def predict_viral_potential(posts: List[Dict]) -> List[Dict]: |
|
""" |
|
Predict viral potential for posts. |
|
""" |
|
|
|
return [{"caption": post["caption"], "viral_score": 0.8} for post in posts] |
|
|
|
def recommend_hashtags(posts: List[Dict]) -> List[str]: |
|
""" |
|
Recommend trending hashtags. |
|
""" |
|
    # Hashtags may be stored as a list or a comma-separated string; handle both.
    hashtags = [
        tag.strip()
        for post in posts
        for tag in (post["hashtags"].split(",")
                    if isinstance(post["hashtags"], str) else post["hashtags"])
    ]
    return [tag for tag, _ in Counter(hashtags).most_common(10)]
|
|
|
|
|
if __name__ == "__main__": |
|
import uvicorn |
|
uvicorn.run(app, host="0.0.0.0", port=8000) |