Fred808 commited on
Commit
d6a44dc
·
verified ·
1 Parent(s): 7994db7

Update utils/database.py

Browse files
Files changed (1) hide show
  1. utils/database.py +7 -6
utils/database.py CHANGED
@@ -35,19 +35,20 @@ def init_db():
35
  Base.metadata.create_all(bind=engine)
36
  print("Database initialized.")
37
 
38
- def post_exists(session: Session, image_url: str) -> bool:
39
  """
40
  Check if a post with the given image_url already exists in the database.
41
  """
42
- return session.query(Post.id).filter(Post.image_url == image_url).first() is not None
 
43
 
44
- def save_to_db(data: List[Dict]):
45
  """
46
  Save data to the PostgreSQL database, avoiding duplicates.
47
  """
48
- with SessionLocal() as session:
49
  for post in data:
50
- if not post_exists(session, post.get("image_url")):
51
  new_post = Post(
52
  username=post.get("username", ""),
53
  caption=post.get("caption", ""),
@@ -61,7 +62,7 @@ def save_to_db(data: List[Dict]):
61
  promote=post.get("promote", False),
62
  )
63
  session.add(new_post)
64
- session.commit()
65
  print("Data saved to database.")
66
 
67
  def fetch_posts_from_db(username: str) -> List[Dict]:
 
35
  Base.metadata.create_all(bind=engine)
36
  print("Database initialized.")
37
 
38
+ async def post_exists(session: AsyncSession, image_url: str) -> bool:
39
  """
40
  Check if a post with the given image_url already exists in the database.
41
  """
42
+ result = await session.execute(select(Post).filter(Post.image_url == image_url))
43
+ return result.scalars().first() is not None
44
 
45
+ async def save_to_db(data: List[Dict]):
46
  """
47
  Save data to the PostgreSQL database, avoiding duplicates.
48
  """
49
+ async with AsyncSessionLocal() as session:
50
  for post in data:
51
+ if not await post_exists(session, post.get("image_url")):
52
  new_post = Post(
53
  username=post.get("username", ""),
54
  caption=post.get("caption", ""),
 
62
  promote=post.get("promote", False),
63
  )
64
  session.add(new_post)
65
+ await session.commit()
66
  print("Data saved to database.")
67
 
68
  def fetch_posts_from_db(username: str) -> List[Dict]: