File size: 4,683 Bytes
58e450d
 
9ba9239
 
 
 
58e450d
9ba9239
 
 
 
 
 
 
 
 
 
 
 
5319d78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ba9239
 
58e450d
5319d78
58e450d
9ba9239
 
5319d78
58e450d
d6a44dc
58e450d
 
 
d6a44dc
 
58e450d
d6a44dc
58e450d
5319d78
58e450d
d6a44dc
5319d78
d6a44dc
5319d78
 
 
 
 
 
 
 
 
 
 
 
 
d6a44dc
5319d78
58e450d
9ba9239
58e450d
 
 
9ba9239
 
 
5319d78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58e450d
9ba9239
58e450d
 
 
9ba9239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import asyncio
import json
import os
from typing import Dict, List

from sqlalchemy import Boolean, Column, Float, Integer, String, Text, select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker

# SQLAlchemy's asyncio extension needs an async DBAPI driver, hence the
# "postgresql+asyncpg" scheme rather than plain "postgresql".
# The URL is read from the DATABASE_URL environment variable so credentials do
# not have to live in source control. The literal fallback only keeps existing
# setups working; remove it once every environment sets DATABASE_URL.
# NOTE(review): the credentials/host part of the fallback reads
# "[email protected]", which looks like an e-mail-obfuscation artifact rather
# than a real "password@host"; confirm the intended URL.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql+asyncpg://postgres.lgbnxplydqdymepehirg:[email protected]:5432/postgres",
)

# Async engine shared by every helper in this module; echo=True logs each SQL
# statement it emits.
engine = create_async_engine(DATABASE_URL, echo=True)

# Factory producing AsyncSession objects bound to `engine`. expire_on_commit=False
# keeps already-loaded attribute values readable after a commit.
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

# Declarative base class that the ORM models in this module inherit from.
Base = declarative_base()

# Define the posts table using SQLAlchemy ORM
class Post(Base):
    """
    ORM mapping for the ``posts`` table.

    ``image_url`` is unique and non-nullable; ``save_to_db`` uses it as the key
    for skipping posts that are already stored. The ``default=...`` values are
    SQLAlchemy (Python-side) defaults; no ``server_default`` is declared.
    """
    __tablename__ = "posts"
    
    # Surrogate integer primary key.
    # NOTE(review): index=True on a primary key is probably redundant — confirm.
    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, nullable=False)
    caption = Column(Text, nullable=True)
    hashtags = Column(Text, nullable=True)  # JSON-encoded list (save_to_db writes it with json.dumps)
    likes = Column(Integer, default=0)
    comments = Column(Integer, default=0)  # a count, not comment text
    # Free-form string rather than a DATE column; the example data uses "YYYY-MM-DD".
    date = Column(String, nullable=True)
    image_url = Column(String, unique=True, nullable=False)  # de-duplication key
    engagement_rate = Column(Float, default=0.0)
    viral_score = Column(Float, default=0.0)
    promote = Column(Boolean, default=False)

# Schema bootstrap: create the ORM tables on the configured database.
async def init_db():
    """
    Create every table registered on ``Base.metadata``.

    The work runs inside a transaction opened by ``engine.begin()``; because
    ``create_all`` is a synchronous API it is dispatched through ``run_sync``.
    A confirmation line is printed once it finishes.
    """
    create_tables = Base.metadata.create_all
    async with engine.begin() as connection:
        await connection.run_sync(create_tables)
    print("Database initialized.")

async def post_exists(session: AsyncSession, image_url: str) -> bool:
    """
    Return True if a post with ``image_url`` is already stored.

    Only the primary key of at most one matching row is requested, so no
    ``Post`` object is built (or added to the session's identity map) just to
    answer a yes/no question.

    Args:
        session: Open async session used to run the query. It is neither
            committed nor closed here.
        image_url: Value compared against ``Post.image_url``.

    Returns:
        True when at least one row matches, otherwise False.
    """
    query = select(Post.id).where(Post.image_url == image_url).limit(1)
    result = await session.execute(query)
    return result.first() is not None

async def save_to_db(data: List[Dict]):
    """
    Insert posts into the database, skipping any whose ``image_url`` is already stored.

    ``image_url`` is the de-duplication key (``Post.image_url`` is unique).
    Existing URLs are found with batched ``IN`` queries instead of one
    ``SELECT`` per post. Repeats *within* ``data`` are skipped too, with the
    first occurrence winning. All new rows are committed together.

    A post whose ``image_url`` is missing or ``None`` is stored as ``""``, and
    that same ``""`` is what the duplicate check looks for. Using one value for
    both means a second such post is recognised as a duplicate instead of
    slipping past a NULL lookup and failing the whole commit on the unique
    constraint.

    Args:
        data: Post dictionaries whose keys mirror the ``Post`` columns. A
            missing key falls back to an empty/zero/False default. ``hashtags``
            (a list) is serialised with ``json.dumps`` before storage.
    """
    # Normalise the de-duplication key once so the lookup and the insert agree.
    keyed_posts = [(post.get("image_url") or "", post) for post in data]
    # Chunk the IN lookups so a large batch stays well under driver bind-parameter limits.
    lookup_chunk = 1000

    async with AsyncSessionLocal() as session:
        # Unique candidate URLs, first-seen order preserved.
        candidate_urls = list(dict.fromkeys(url for url, _ in keyed_posts))
        seen_urls = set()
        for start in range(0, len(candidate_urls), lookup_chunk):
            chunk = candidate_urls[start:start + lookup_chunk]
            result = await session.execute(
                select(Post.image_url).where(Post.image_url.in_(chunk))
            )
            seen_urls.update(result.scalars().all())

        for image_url, post in keyed_posts:
            if image_url in seen_urls:
                continue
            # Record it so a later duplicate in the same batch is skipped.
            seen_urls.add(image_url)
            session.add(
                Post(
                    username=post.get("username", ""),
                    caption=post.get("caption", ""),
                    hashtags=json.dumps(post.get("hashtags", [])),  # list -> JSON string
                    likes=post.get("likes", 0),
                    comments=post.get("comments", 0),
                    date=post.get("date", ""),
                    image_url=image_url,
                    engagement_rate=post.get("engagement_rate", 0.0),
                    viral_score=post.get("viral_score", 0.0),
                    promote=post.get("promote", False),
                )
            )
        await session.commit()
    print("Data saved to database.")

async def fetch_posts_from_db(username: str) -> List[Dict]:
    """
    Return every stored post belonging to ``username`` as plain dictionaries.

    The dictionaries use the same keys ``save_to_db`` accepts, with
    ``hashtags`` decoded back from its JSON string. ``hashtags`` is a nullable
    column, so a NULL or empty value is returned as ``[]`` rather than making
    ``json.loads`` raise.

    Args:
        username: Exact value matched against ``Post.username``.

    Returns:
        One dict per matching row, in whatever order the database returns them
        (no ORDER BY is applied).
    """
    stmt = select(Post).where(Post.username == username)
    async with AsyncSessionLocal() as session:
        result = await session.execute(stmt)
        posts = result.scalars().all()
        return [
            {
                "username": post.username,
                "caption": post.caption,
                # JSON string -> list; tolerate NULL/empty rows.
                "hashtags": json.loads(post.hashtags) if post.hashtags else [],
                "likes": post.likes,
                "comments": post.comments,
                "date": post.date,
                "image_url": post.image_url,
                "engagement_rate": post.engagement_rate,
                "viral_score": post.viral_score,
                "promote": post.promote,
            }
            for post in posts
        ]

async def get_db():
    """
    Async generator that yields a single database session.

    The session is yielded exactly once. When the consumer finishes with the
    generator, the enclosing ``async with`` closes the session.
    """
    session_factory = AsyncSessionLocal
    async with session_factory() as db_session:
        yield db_session

# Example usage: create the schema, store one sample post, and read it back.
async def main():
    """
    Exercise this module end to end against the configured database.

    In the visible code it creates the tables, saves one sample post for
    ``test_user`` (skipped on later runs because its ``image_url`` is already
    stored), then fetches and prints that user's posts.

    NOTE(review): nothing in the visible part of this file calls ``main()``;
    if there is no ``asyncio.run(main())`` entry point elsewhere, consider
    adding one under ``if __name__ == "__main__":``.
    """
    # Initialize the database
    await init_db()

    # Example data to save (keys mirror the Post columns; hashtags is a list)
    example_data = [
        {
            "username": "test_user",
            "caption": "This is a test post",
            "hashtags": ["test", "example"],
            "likes": 10,
            "comments": 2,
            "date": "2025-01-27",
            "image_url": "https://example.com/image1.jpg",
            "engagement_rate": 0.5,
            "viral_score": 0.8,
            "promote": False,
        }
    ]

    # Save data to the database (duplicates by image_url are skipped)
    await save_to_db(example_data)

    # Fetch posts from the database
    posts = await fetch_posts_from_db("test_user")
    print("Fetched posts:", posts)