File size: 8,125 Bytes
58e450d 68f1b21 58e450d aff69f7 58e450d 6c6ffe2 faaa5be 58e450d 75dc8d7 58e450d 4e1cf33 68f1b21 4e1cf33 58e450d 68f1b21 58e450d 68f1b21 58e450d 68f1b21 58e450d 9aaf00a 58e450d 68f1b21 58e450d 68f1b21 58e450d 68f1b21 58e450d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
# api/main.py
import inspect
import logging
import re
import time
from collections import Counter
from functools import lru_cache
from io import BytesIO
from typing import List, Dict

import joblib
import pandas as pd
import pytesseract
import requests
import torch
from fastapi import FastAPI, HTTPException, Depends, status
from PIL import Image
from pydantic import BaseModel, constr
from sqlalchemy.orm import Session
from textblob import TextBlob
from torchvision import transforms
from transformers import ResNetForImageClassification

from utils.database import init_db, save_to_db, fetch_posts_from_db, get_db
from utils.instaloader_utils import fetch_user_posts, fetch_competitors_posts
# Set up logging
# Root logger: INFO and above, each record rendered as "<time> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Initialize FastAPI app
# The route decorators further down register their handlers on this instance.
app = FastAPI()
# Initialize database
# Runs once at import time, before any request is served.
init_db()
# Load models
# The three artifacts are read at import time from the relative "models/" directory,
# so the process must be started from a working directory that contains it.
# NOTE(review): joblib.load unpickles the file; only ever point these at trusted artifacts.
viral_model = joblib.load("models/viral_potential_model.pkl")
engagement_model = joblib.load("models/engagement_rate_model.pkl")
promotion_model = joblib.load("models/promotion_strategy_model.pkl")
class UserRequest(BaseModel):
    """Request body for the endpoints that operate on a single account."""

    # Account name whose posts are fetched or analyzed.
    username: str
class AnalyzePostRequest(BaseModel):
    """Request body describing one post to be scored by ``/analyze-post``."""

    # Free-text caption of the post.
    caption: str
    # Hashtags as a single comma-separated string.
    hashtags: str
    # URL the post's image is downloaded from.
    image_url: str
# Configure logging
# NOTE(review): logging.basicConfig is already called earlier in this module with an
# explicit format; the stdlib makes a repeat call a no-op once the root logger has
# handlers, so this line is presumably redundant — confirm and remove.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Rate limiting variables
RATE_LIMIT_DELAY = 5 # Delay in seconds between API calls
# time.time() timestamp of the last accepted /fetch-posts request; 0 means none yet.
LAST_REQUEST_TIME = 0
@app.post("/fetch-posts")
async def fetch_posts(user: UserRequest):
    """
    Fetch posts from a given Instagram profile (public data only).

    Args:
        user: Request body carrying the ``username`` whose posts are fetched.

    Returns:
        A dict with ``status``, the fetched posts under ``data`` and a
        human-readable ``message``.

    Raises:
        HTTPException: 429 when called within ``RATE_LIMIT_DELAY`` seconds of
            the previously accepted request, 404 when no posts are found,
            500 when saving fails or any unexpected error occurs.
    """
    global LAST_REQUEST_TIME
    username = user.username
    logger.info(f"Fetching posts for user: {username}")

    # Rate limiting: a single process-wide timestamp shared by every caller.
    current_time = time.time()
    if current_time - LAST_REQUEST_TIME < RATE_LIMIT_DELAY:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Please wait a few seconds before making another request."
        )
    LAST_REQUEST_TIME = current_time

    try:
        # Fetch user's posts. The helper may be synchronous or a coroutine
        # function; resolve the result *before* testing it for emptiness so
        # neither flavour breaks (awaiting a plain list raises TypeError, and
        # an un-awaited coroutine object is always truthy).
        user_posts = fetch_user_posts(username)
        if inspect.isawaitable(user_posts):
            user_posts = await user_posts

        if not user_posts:
            logger.warning(f"No posts found for user: {username}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No posts found for the user."
            )

        # Only the user's own posts are collected here.
        all_posts = user_posts

        # Save data to the database
        if not save_to_db(all_posts):
            logger.error("Failed to save data to the database.")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to save data to the database."
            )

        # Return success response
        return {
            "status": "success",
            "data": all_posts,
            "message": f"Successfully fetched {len(all_posts)} posts."
        }
    except HTTPException:
        # Re-raise unchanged so the specific status code and message survive.
        raise
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Unexpected error fetching posts: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred. Please try again later."
        ) from e
@app.post("/analyze")
async def analyze(user: UserRequest, db: Session = Depends(get_db)):
    """
    Analyze the stored posts of a user.

    Args:
        user: Request body carrying the ``username`` to analyze.
        db: Database session injected by FastAPI; not used directly in this
            handler.

    Returns:
        A dict with ``status`` and ``results``; ``results`` holds
        ``viral_potential``, ``top_hashtags`` and ``engagement_stats`` (mean
        likes and comments per post).

    Raises:
        HTTPException: 404 when no posts are stored for the user, 500 for any
            other failure.
    """
    username = user.username
    logger.info(f"Analyzing data for user: {username}")
    try:
        # Fetch data from the database
        user_posts = fetch_posts_from_db(username)
        if not user_posts:
            raise HTTPException(status_code=404, detail="No posts found for the user.")

        # Non-zero here because of the emptiness check above.
        post_count = len(user_posts)

        # Perform analysis (e.g., viral potential, engagement rate, etc.)
        analysis_results = {
            "viral_potential": predict_viral_potential(user_posts),
            "top_hashtags": recommend_hashtags(user_posts),
            "engagement_stats": {
                "mean_likes": sum(post['likes'] for post in user_posts) / post_count,
                "mean_comments": sum(post['comments'] for post in user_posts) / post_count
            }
        }
        return {"status": "success", "results": analysis_results}
    except HTTPException:
        # Without this clause the 404 above is caught by the generic handler
        # below and reported to the client as a 500.
        raise
    except Exception as e:
        logger.exception(f"Error analyzing data: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
_IMAGE_DOWNLOAD_TIMEOUT = 10  # seconds to wait for the remote image before giving up


def _to_native(value):
    """Return *value* as a built-in Python scalar when it is a NumPy scalar."""
    return value.item() if hasattr(value, "item") else value


@app.post("/analyze-post")
async def analyze_post(post: AnalyzePostRequest, db: Session = Depends(get_db)):
    """
    Analyze a single post (caption, hashtags, and image).

    Args:
        post: Caption, comma-separated hashtags and image URL of the post.
        db: Database session injected by FastAPI; not used directly in this
            handler.

    Returns:
        A dict with ``extracted_text``, ``image_analysis``, ``viral_score``,
        ``engagement_rate`` and ``promote``.

    Raises:
        HTTPException: 500 when the image cannot be downloaded or any step of
            the analysis fails.
    """
    try:
        # Download and analyze the image
        # NOTE(security): image_url is caller-supplied and fetched server-side,
        # which can be abused to reach internal hosts (SSRF). Consider
        # restricting allowed schemes/hosts.
        # The timeout stops a stalled remote server from hanging the request.
        response = requests.get(post.image_url, timeout=_IMAGE_DOWNLOAD_TIMEOUT)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))

        # Extract text from the image
        extracted_text = extract_text_from_image(image)

        # Analyze the image content
        image_analysis = analyze_image(image)

        # Preprocess input for models
        # Count only non-blank entries: "".split(",") yields [""], which would
        # otherwise report one hashtag for a post that has none.
        hashtag_count = sum(1 for tag in post.hashtags.split(",") if tag.strip())
        features = {
            'caption_length': len(post.caption),
            'hashtag_count': hashtag_count,
            'sentiment': TextBlob(post.caption).sentiment.polarity
        }
        features_df = pd.DataFrame([features])

        # Make predictions; convert NumPy scalars to built-in types so they
        # can be serialised in the response and stored.
        viral_score = _to_native(viral_model.predict_proba(features_df)[0][1])
        engagement_rate = _to_native(engagement_model.predict(features_df)[0])
        promote = bool(promotion_model.predict(features_df)[0])

        # Save post to database
        post_data = {
            "caption": post.caption,
            "hashtags": post.hashtags,
            "image_url": post.image_url,
            "engagement_rate": engagement_rate,
            "viral_score": viral_score,
            "promote": promote
        }
        # Persisting is best-effort here: a failed save is logged but does not
        # fail the analysis request.
        if not save_to_db([post_data]):
            logging.warning("Analyzed post could not be saved to the database.")

        return {
            "extracted_text": extracted_text,
            "image_analysis": image_analysis,
            "viral_score": viral_score,
            "engagement_rate": engagement_rate,
            "promote": promote
        }
    except Exception as e:
        logging.exception(f"Error analyzing post: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Image processing functions
def resize_image(image, max_size=(800, 600)):
    """Return a copy of *image* shrunk to fit within *max_size*.

    ``thumbnail`` resizes in place, so it is applied to a copy; the caller's
    image keeps its original resolution for any later processing.

    Args:
        image: Image object providing ``copy()`` and ``thumbnail()``
            (a ``PIL.Image.Image``).
        max_size: ``(width, height)`` bounding box passed to ``thumbnail``.

    Returns:
        The resized copy.
    """
    resized = image.copy()
    resized.thumbnail(max_size)
    return resized
def extract_text_from_image(image):
    """Run OCR on *image* and return the recognised text.

    The image goes through ``resize_image`` before being handed to
    ``pytesseract.image_to_string``. This is best-effort: any failure is
    logged and an empty string is returned instead of raising.
    """
    try:
        return pytesseract.image_to_string(resize_image(image))
    except Exception as e:
        logging.error(f"Error extracting text from image: {e}")
    return ""
@lru_cache(maxsize=1)
def _load_resnet():
    """Load the pre-trained ResNet-50 classifier once and reuse it across calls."""
    model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
    model.eval()
    return model


def analyze_image(image):
    """Analyze image content using a pre-trained model.

    Args:
        image: ``PIL.Image.Image`` to classify.

    Returns:
        The model's raw logits as a nested list (one inner list for the single
        image in the batch), or ``None`` if anything fails; failures are
        logged rather than raised.
    """
    try:
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Normalize is configured for three channels; force RGB so RGBA,
        # palette and grayscale images do not fail on a channel mismatch.
        image_tensor = preprocess(image.convert("RGB")).unsqueeze(0)

        # Load ResNet model (cached: loading it per request is very slow).
        model = _load_resnet()
        with torch.no_grad():
            output = model(image_tensor)
        return output.logits.tolist()  # Return the logits as a list
    except Exception as e:
        logging.error(f"Error analyzing image: {e}")
        return None
# Helper functions
def predict_viral_potential(posts: List[Dict]) -> List[Dict]:
    """
    Attach a viral score to every post.

    Returns one ``{"caption", "viral_score"}`` dict per input post, in input
    order. The score is currently a fixed placeholder.
    """
    # Placeholder for viral potential prediction logic
    scored = []
    for post in posts:
        scored.append({"caption": post["caption"], "viral_score": 0.8})
    return scored
def recommend_hashtags(posts: List[Dict], top_n: int = 10) -> List[str]:
    """
    Recommend the most frequently used hashtags across *posts*.

    Args:
        posts: Post dicts. Each ``"hashtags"`` value may be a list of hashtag
            strings or a single comma-separated string; a missing or empty
            value contributes nothing.
        top_n: Maximum number of hashtags to return (default 10).

    Returns:
        Up to *top_n* hashtags, most frequent first; ties keep first-seen
        order.
    """
    hashtag_counts = Counter()
    for post in posts:
        raw_tags = post.get("hashtags") or []
        # A plain string must be split first; iterating it directly would
        # count individual characters instead of hashtags.
        if isinstance(raw_tags, str):
            raw_tags = raw_tags.split(",")
        cleaned = (tag.strip() for tag in raw_tags)
        hashtag_counts.update(tag for tag in cleaned if tag)
    return [hashtag for hashtag, _ in hashtag_counts.most_common(top_n)]
# Run the API
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the module is run as a script.
    import uvicorn

    # Serve the app on all network interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)