Sam Fred commited on
Commit
5319d78
·
1 Parent(s): 75dc8d7

Update database.py

Browse files
Files changed (1) hide show
  1. utils/database.py +67 -81
utils/database.py CHANGED
@@ -1,104 +1,90 @@
1
- # api/utils/database.py
2
- import sqlite3
3
  import json
4
  from typing import List, Dict
5
- from sqlalchemy import create_engine
 
6
  from sqlalchemy.orm import sessionmaker, Session
7
 
8
- DATABASE = "instagram_ai.db"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  def init_db():
11
  """
12
- Initialize the SQLite database.
13
  """
14
- conn = sqlite3.connect(DATABASE)
15
- cursor = conn.cursor()
16
- cursor.execute('''
17
- CREATE TABLE IF NOT EXISTS posts (
18
- id INTEGER PRIMARY KEY AUTOINCREMENT,
19
- username TEXT NOT NULL,
20
- caption TEXT,
21
- hashtags TEXT,
22
- likes INTEGER,
23
- comments INTEGER,
24
- date TEXT,
25
- image_url TEXT UNIQUE,
26
- engagement_rate REAL,
27
- viral_score REAL,
28
- promote BOOLEAN
29
- )
30
- ''')
31
- conn.commit()
32
- conn.close()
33
 
34
- def post_exists(image_url: str) -> bool:
35
  """
36
  Check if a post with the given image_url already exists in the database.
37
  """
38
- conn = sqlite3.connect(DATABASE)
39
- cursor = conn.cursor()
40
- cursor.execute('SELECT id FROM posts WHERE image_url = ?', (image_url,))
41
- result = cursor.fetchone()
42
- conn.close()
43
- return result is not None
44
 
45
  def save_to_db(data: List[Dict]):
46
  """
47
- Save data to the SQLite database, avoiding duplicates.
48
  """
49
- conn = sqlite3.connect(DATABASE)
50
- cursor = conn.cursor()
51
- for post in data:
52
- if not post_exists(post.get('image_url')):
53
- cursor.execute('''
54
- INSERT INTO posts (username, caption, hashtags, likes, comments, date, image_url, engagement_rate, viral_score, promote)
55
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
56
- ''', (
57
- post.get('username', ''),
58
- post.get('caption', ''),
59
- json.dumps(post.get('hashtags', [])),
60
- post.get('likes', 0),
61
- post.get('comments', 0),
62
- post.get('date', ''),
63
- post.get('image_url', ''),
64
- post.get('engagement_rate', 0.0),
65
- post.get('viral_score', 0.0),
66
- post.get('promote', False)
67
- ))
68
- conn.commit()
69
- conn.close()
70
- print(f"Data saved to database: {DATABASE}")
71
 
72
  def fetch_posts_from_db(username: str) -> List[Dict]:
73
  """
74
  Fetch posts from the database for a given username.
75
  """
76
- conn = sqlite3.connect(DATABASE)
77
- cursor = conn.cursor()
78
- cursor.execute('SELECT * FROM posts WHERE username = ?', (username,))
79
- rows = cursor.fetchall()
80
- conn.close()
81
-
82
-
83
- posts = []
84
- for row in rows:
85
- posts.append({
86
- "username": row[1],
87
- "caption": row[2],
88
- "hashtags": json.loads(row[3]),
89
- "likes": row[4],
90
- "comments": row[5],
91
- "date": row[6],
92
- "image_url": row[7],
93
- "engagement_rate": row[8],
94
- "viral_score": row[9],
95
- "promote": bool(row[10])
96
- })
97
- return posts
98
-
99
- SQLALCHEMY_DATABASE_URL = f"sqlite:///{DATABASE}"
100
- engine = create_engine(SQLALCHEMY_DATABASE_URL)
101
- SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
102
 
103
  def get_db():
104
  """
@@ -108,4 +94,4 @@ def get_db():
108
  try:
109
  yield db
110
  finally:
111
- db.close()
 
 
 
1
  import json
2
  from typing import List, Dict
3
+ from sqlalchemy import create_engine, Column, Integer, String, Float, Boolean, Text
4
+ from sqlalchemy.ext.declarative import declarative_base
5
  from sqlalchemy.orm import sessionmaker, Session
6
 
7
+ DATABASE_URL = "postgresql://postgres:Lovyelias5584.@db.lgbnxplydqdymepehirg.supabase.co:5432/postgres"
8
+
9
+ Base = declarative_base()
10
+
11
+ # Define the posts table using SQLAlchemy ORM
12
+ class Post(Base):
13
+ __tablename__ = "posts"
14
+
15
+ id = Column(Integer, primary_key=True, index=True)
16
+ username = Column(String, nullable=False)
17
+ caption = Column(Text, nullable=True)
18
+ hashtags = Column(Text, nullable=True) # Store as JSON string
19
+ likes = Column(Integer, default=0)
20
+ comments = Column(Integer, default=0)
21
+ date = Column(String, nullable=True)
22
+ image_url = Column(String, unique=True, nullable=False)
23
+ engagement_rate = Column(Float, default=0.0)
24
+ viral_score = Column(Float, default=0.0)
25
+ promote = Column(Boolean, default=False)
26
+
27
+ # Initialize SQLAlchemy engine and session maker
28
+ engine = create_engine(DATABASE_URL)
29
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
30
 
31
  def init_db():
32
  """
33
+ Initialize the PostgreSQL database by creating tables.
34
  """
35
+ Base.metadata.create_all(bind=engine)
36
+ print("Database initialized.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ def post_exists(session: Session, image_url: str) -> bool:
39
  """
40
  Check if a post with the given image_url already exists in the database.
41
  """
42
+ return session.query(Post.id).filter(Post.image_url == image_url).first() is not None
 
 
 
 
 
43
 
44
  def save_to_db(data: List[Dict]):
45
  """
46
+ Save data to the PostgreSQL database, avoiding duplicates.
47
  """
48
+ with SessionLocal() as session:
49
+ for post in data:
50
+ if not post_exists(session, post.get("image_url")):
51
+ new_post = Post(
52
+ username=post.get("username", ""),
53
+ caption=post.get("caption", ""),
54
+ hashtags=json.dumps(post.get("hashtags", [])), # Convert list to JSON string
55
+ likes=post.get("likes", 0),
56
+ comments=post.get("comments", 0),
57
+ date=post.get("date", ""),
58
+ image_url=post.get("image_url", ""),
59
+ engagement_rate=post.get("engagement_rate", 0.0),
60
+ viral_score=post.get("viral_score", 0.0),
61
+ promote=post.get("promote", False),
62
+ )
63
+ session.add(new_post)
64
+ session.commit()
65
+ print("Data saved to database.")
 
 
 
 
66
 
67
  def fetch_posts_from_db(username: str) -> List[Dict]:
68
  """
69
  Fetch posts from the database for a given username.
70
  """
71
+ with SessionLocal() as session:
72
+ posts = session.query(Post).filter(Post.username == username).all()
73
+ return [
74
+ {
75
+ "username": post.username,
76
+ "caption": post.caption,
77
+ "hashtags": json.loads(post.hashtags), # Convert JSON string back to list
78
+ "likes": post.likes,
79
+ "comments": post.comments,
80
+ "date": post.date,
81
+ "image_url": post.image_url,
82
+ "engagement_rate": post.engagement_rate,
83
+ "viral_score": post.viral_score,
84
+ "promote": post.promote,
85
+ }
86
+ for post in posts
87
+ ]
 
 
 
 
 
 
 
 
 
88
 
89
  def get_db():
90
  """
 
94
  try:
95
  yield db
96
  finally:
97
+ db.close()