Fred808 commited on
Commit
9ba9239
·
verified ·
1 Parent(s): 5039b15

Update utils/database.py

Browse files
Files changed (1) hide show
  1. utils/database.py +55 -19
utils/database.py CHANGED
@@ -1,10 +1,22 @@
1
  import json
2
  from typing import List, Dict
3
- from sqlalchemy import create_engine, Column, Integer, String, Float, Boolean, Text
4
- from sqlalchemy.ext.declarative import declarative_base
5
- from sqlalchemy.orm import sessionmaker, Session
 
6
 
7
- DATABASE_URL = "postgresql://postgres.lgbnxplydqdymepehirg:[email protected]:5432/postgres"
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  Base = declarative_base()
10
 
@@ -24,15 +36,13 @@ class Post(Base):
24
  viral_score = Column(Float, default=0.0)
25
  promote = Column(Boolean, default=False)
26
 
27
- # Initialize SQLAlchemy engine and session maker
28
- engine = create_engine(DATABASE_URL)
29
- SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
30
-
31
- def init_db():
32
  """
33
  Initialize the PostgreSQL database by creating tables.
34
  """
35
- Base.metadata.create_all(bind=engine)
 
36
  print("Database initialized.")
37
 
38
  async def post_exists(session: AsyncSession, image_url: str) -> bool:
@@ -65,12 +75,13 @@ async def save_to_db(data: List[Dict]):
65
  await session.commit()
66
  print("Data saved to database.")
67
 
68
- def fetch_posts_from_db(username: str) -> List[Dict]:
69
  """
70
  Fetch posts from the database for a given username.
71
  """
72
- with SessionLocal() as session:
73
- posts = session.query(Post).filter(Post.username == username).all()
 
74
  return [
75
  {
76
  "username": post.username,
@@ -87,12 +98,37 @@ def fetch_posts_from_db(username: str) -> List[Dict]:
87
  for post in posts
88
  ]
89
 
90
- def get_db():
91
  """
92
  Dependency to get a database session.
93
  """
94
- db = SessionLocal()
95
- try:
96
- yield db
97
- finally:
98
- db.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  from typing import List, Dict
3
+ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
4
+ from sqlalchemy.orm import sessionmaker, declarative_base
5
+ from sqlalchemy import Column, Integer, String, Float, Boolean, Text, select
6
+ import asyncio
7
 
8
+ # Use the async database URL (replace "postgresql" with "postgresql+asyncpg")
9
+ DATABASE_URL = "postgresql+asyncpg://postgres.lgbnxplydqdymepehirg:[email protected]:5432/postgres"
10
+
11
+ # Initialize the async engine
12
+ engine = create_async_engine(DATABASE_URL, echo=True)
13
+
14
+ # Create an async session maker
15
+ AsyncSessionLocal = sessionmaker(
16
+ bind=engine,
17
+ class_=AsyncSession,
18
+ expire_on_commit=False,
19
+ )
20
 
21
  Base = declarative_base()
22
 
 
36
  viral_score = Column(Float, default=0.0)
37
  promote = Column(Boolean, default=False)
38
 
39
+ # Initialize the database (create tables)
40
+ async def init_db():
 
 
 
41
  """
42
  Initialize the PostgreSQL database by creating tables.
43
  """
44
+ async with engine.begin() as conn:
45
+ await conn.run_sync(Base.metadata.create_all)
46
  print("Database initialized.")
47
 
48
  async def post_exists(session: AsyncSession, image_url: str) -> bool:
 
75
  await session.commit()
76
  print("Data saved to database.")
77
 
78
+ async def fetch_posts_from_db(username: str) -> List[Dict]:
79
  """
80
  Fetch posts from the database for a given username.
81
  """
82
+ async with AsyncSessionLocal() as session:
83
+ result = await session.execute(select(Post).filter(Post.username == username))
84
+ posts = result.scalars().all()
85
  return [
86
  {
87
  "username": post.username,
 
98
  for post in posts
99
  ]
100
 
101
+ async def get_db():
102
  """
103
  Dependency to get a database session.
104
  """
105
+ async with AsyncSessionLocal() as session:
106
+ yield session
107
+
108
+ # Example usage
109
+ async def main():
110
+ # Initialize the database
111
+ await init_db()
112
+
113
+ # Example data to save
114
+ example_data = [
115
+ {
116
+ "username": "test_user",
117
+ "caption": "This is a test post",
118
+ "hashtags": ["test", "example"],
119
+ "likes": 10,
120
+ "comments": 2,
121
+ "date": "2025-01-27",
122
+ "image_url": "https://example.com/image1.jpg",
123
+ "engagement_rate": 0.5,
124
+ "viral_score": 0.8,
125
+ "promote": False,
126
+ }
127
+ ]
128
+
129
+ # Save data to the database
130
+ await save_to_db(example_data)
131
+
132
+ # Fetch posts from the database
133
+ posts = await fetch_posts_from_db("test_user")
134
+ print("Fetched posts:", posts)