Viral-808 / utils /database.py
Fred808's picture
Update utils/database.py
9ba9239 verified
import asyncio
import json
import os
from typing import List, Dict

from sqlalchemy import Column, Integer, String, Float, Boolean, Text, select
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
# Connection string for the async engine; the "postgresql+asyncpg" scheme
# selects the asyncpg driver.  A non-empty DATABASE_URL environment variable
# takes precedence so credentials do not have to live in source control.  The
# literal below is kept only as a backward-compatible fallback.
# NOTE(review): the fallback embeds credentials -- rotate them and drop the
# literal once every deployment sets DATABASE_URL.
DATABASE_URL = os.environ.get("DATABASE_URL") or (
    "postgresql+asyncpg://postgres.lgbnxplydqdymepehirg:[email protected]:5432/postgres"
)

# Initialize the async engine (echo=True makes SQLAlchemy log emitted SQL).
engine = create_async_engine(DATABASE_URL, echo=True)

# Create an async session maker bound to the engine.  expire_on_commit=False
# keeps already-loaded attributes usable after a commit.
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()
# Define the posts table using SQLAlchemy ORM
class Post(Base):
    """ORM model mapped to the ``posts`` table."""

    __tablename__ = "posts"

    # Integer primary key, indexed.
    id = Column(Integer, index=True, primary_key=True)

    # Account name the post belongs to; required.
    username = Column(String, nullable=False)

    # Free-form caption text; optional.
    caption = Column(Text, nullable=True)

    # Hashtags kept as a JSON-encoded string; optional.
    hashtags = Column(Text, nullable=True)

    # Interaction counters, defaulting to zero.
    likes = Column(Integer, default=0)
    comments = Column(Integer, default=0)

    # Post date kept as a plain string; optional.
    date = Column(String, nullable=True)

    # Image location; required and unique across rows.
    image_url = Column(String, nullable=False, unique=True)

    # Scoring fields, defaulting to 0.0.
    engagement_rate = Column(Float, default=0.0)
    viral_score = Column(Float, default=0.0)

    # Promotion flag, off by default.
    promote = Column(Boolean, default=False)
# Initialize the database (create tables)
async def init_db():
    """
    Create the tables registered on ``Base.metadata`` in the PostgreSQL
    database, then print a confirmation message.
    """
    create_tables = Base.metadata.create_all
    async with engine.begin() as connection:
        # create_all is a synchronous call, so it is dispatched via run_sync.
        await connection.run_sync(create_tables)
    print("Database initialized.")
async def post_exists(session: AsyncSession, image_url: str) -> bool:
    """
    Report whether the database already holds a post with this image_url.

    Runs a SELECT on ``posts`` through the supplied session and returns
    True when at least one matching row comes back.
    """
    query = select(Post).filter(Post.image_url == image_url)
    rows = await session.execute(query)
    first_match = rows.scalars().first()
    return first_match is not None
async def save_to_db(data: List[Dict]):
    """
    Save data to the PostgreSQL database, avoiding duplicates.

    Each dict in ``data`` describes one post.  Missing keys fall back to the
    defaults used below.  A post is added only when ``post_exists`` reports
    no row with the same ``image_url``.  All additions are committed in a
    single transaction after the loop.
    """
    async with AsyncSessionLocal() as session:
        for post in data:
            # Resolve the URL once so the duplicate check and the insert use
            # the same value.  Checking with None while storing "" let a
            # second post without an image_url slip past the check and then
            # violate the unique constraint on image_url.
            image_url = post.get("image_url", "")
            if await post_exists(session, image_url):
                continue

            new_post = Post(
                username=post.get("username", ""),
                caption=post.get("caption", ""),
                # Convert the hashtag list to a JSON string for the Text column.
                hashtags=json.dumps(post.get("hashtags", [])),
                likes=post.get("likes", 0),
                comments=post.get("comments", 0),
                date=post.get("date", ""),
                image_url=image_url,
                engagement_rate=post.get("engagement_rate", 0.0),
                viral_score=post.get("viral_score", 0.0),
                promote=post.get("promote", False),
            )
            session.add(new_post)
        await session.commit()
    print("Data saved to database.")
async def fetch_posts_from_db(username: str) -> List[Dict]:
    """
    Fetch posts from the database for a given username.

    Returns one dict per matching row with the keys ``username``,
    ``caption``, ``hashtags``, ``likes``, ``comments``, ``date``,
    ``image_url``, ``engagement_rate``, ``viral_score`` and ``promote``.
    ``hashtags`` is decoded from its JSON string back into a list.  A NULL
    or empty value yields an empty list.
    """
    async with AsyncSessionLocal() as session:
        result = await session.execute(select(Post).filter(Post.username == username))
        posts = result.scalars().all()
        return [
            {
                "username": post.username,
                "caption": post.caption,
                # The hashtags column is nullable; json.loads(None) would
                # raise TypeError, so fall back to an empty list.
                "hashtags": json.loads(post.hashtags) if post.hashtags else [],
                "likes": post.likes,
                "comments": post.comments,
                "date": post.date,
                "image_url": post.image_url,
                "engagement_rate": post.engagement_rate,
                "viral_score": post.viral_score,
                "promote": post.promote,
            }
            for post in posts
        ]
async def get_db():
    """
    Async generator that yields one database session.

    The session is opened from ``AsyncSessionLocal`` and stays inside its
    ``async with`` block while the consumer uses it.
    """
    async with AsyncSessionLocal() as db_session:
        yield db_session
# Example usage
async def main():
    """
    Demonstrate the module: create the tables, store one sample post, then
    read back and print the posts of ``test_user``.
    """
    # Make sure the tables exist before writing.
    await init_db()

    # A single sample record to persist.
    sample_post = {
        "username": "test_user",
        "caption": "This is a test post",
        "hashtags": ["test", "example"],
        "likes": 10,
        "comments": 2,
        "date": "2025-01-27",
        "image_url": "https://example.com/image1.jpg",
        "engagement_rate": 0.5,
        "viral_score": 0.8,
        "promote": False,
    }
    await save_to_db([sample_post])

    # Read the stored posts back and show them.
    fetched = await fetch_posts_from_db("test_user")
    print("Fetched posts:", fetched)