Libidrave commited on
Commit
e94b921
·
verified ·
1 Parent(s): c8d159b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -123
app.py CHANGED
@@ -1,123 +1,123 @@
1
- __import__('pysqlite3')
2
- import sys
3
- sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
4
-
5
- import os
6
- import time
7
- from uuid import uuid4
8
-
9
- from langchain_openai import ChatOpenAI
10
- from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
11
-
12
- from langchain.chains import create_retrieval_chain
13
- from langchain_core.tools import tool
14
-
15
- from utils.preprocess import load_data, split_data, upsert_chromadb
16
- from utils.prebuilt_chain import history_aware_retriever, documents_retriever
17
-
18
- import streamlit as st
19
-
20
- db_name = "chroma" # default name for Chromadb
21
-
22
- st.set_page_config(page_title="RAG Demo App")
23
- st.title("Demo Retrieval Augmented Generation With LanghChain & Chroma")
24
-
25
- @st.cache_resource
26
- def load_model(api_key):
27
- """cached llm and embedding model"""
28
- if st.session_state.provider == "OpenAI":
29
- return ChatOpenAI(model="gpt-4o-mini", temperature=0.3, api_key=api_key)
30
- elif st.session_state.provider == "Groq":
31
- return ChatOpenAI(model="llama-3.1-8b-instant", temperature=0.3, api_key=api_key, base_url="https://api.groq.com/openai/v1")
32
-
33
- @st.cache_resource
34
- def load_embedding():
35
- st.session_state.embedding = FastEmbedEmbeddings(model_name="jinaai/jina-embeddings-v2-base-de",
36
- batch_size=64)
37
-
38
- def inputs():
39
- """Input fields for user interaction"""
40
- with st.sidebar:
41
- st.session_state.provider = st.radio("Pilih model LLM", ["OpenAI", "Groq"])
42
-
43
- st.session_state.api_key = st.text_input("Masukkan API Key", type="password")
44
- os.environ["OPENAI_API_KEY"] = st.session_state.api_key
45
-
46
- st.session_state.chroma_collection_name = st.text_input("Chroma Collection Name")
47
-
48
- st.session_state.source_docs = st.file_uploader("Unggah file PDF", type=["pdf"], accept_multiple_files=True)
49
- st.button("Proses Dokumen", on_click=process_data)
50
-
51
- def process_data():
52
- """Main function to process data"""
53
- if not st.session_state.api_key or not st.session_state.chroma_collection_name or not st.session_state.source_docs:
54
- st.error("Tolong masukan API key, Chroma collection name, dan dokumen yang diperlukan!!")
55
- else:
56
- with st.spinner("📚 Memproses dokumen..."):
57
- loaded_docs = load_data(st.session_state.source_docs)
58
- splitted_docs = split_data(loaded_docs)
59
-
60
- idx = [str(uuid4()) for _ in range(len(splitted_docs))]
61
-
62
- st.session_state.vector_store = upsert_chromadb(splitted_docs,
63
- st.session_state.embedding,
64
- idx,
65
- st.session_state.chroma_collection_name,
66
- db_name)
67
- msg = st.empty()
68
- msg.success("Dokumen berhasil diproses!")
69
- time.sleep(3)
70
- msg.empty()
71
-
72
- # Main retriever
73
- @tool(response_format="content_and_artifact")
74
- def retrieve(query: str):
75
- """Retrieve information related to a query.
76
-
77
- Args:
78
- query: The user's query.
79
- """
80
- retrieved_docs = st.session_state.vector_store.similarity_search(query, k=6)
81
- keys = ["author", "creator", "page", "source", "start_index", "total_pages"]
82
- serialized = "\n\n".join(
83
- (f"Source: {[{key: doc.metadata.get(key)} for key in keys]}\n" f"Content: {doc.page_content}")
84
- for doc in retrieved_docs
85
- )
86
- return serialized, retrieved_docs
87
-
88
- def generate(query):
89
- """Generate a response to the user's query."""
90
- # Dummy retriever.
91
- retriever = st.session_state.vector_store.as_retriever(search_kwargs={"k" : 1})
92
-
93
- # Create a RAG chain using the history-aware retriever and the document-retriever.
94
- history_retriever = history_aware_retriever(st.session_state.llm, retriever)
95
- question_answer_chain = documents_retriever(st.session_state.llm)
96
-
97
- rag_chain = create_retrieval_chain(history_retriever, question_answer_chain)
98
-
99
- # Usage:
100
- response = rag_chain.invoke({"input": query, "chat_history" : st.session_state.messages, "context" : retrieve.invoke(query)})
101
- st.session_state.messages.append(query)
102
- st.session_state.messages.append(response["answer"])
103
- return response["answer"]
104
-
105
- if __name__ == "__main__":
106
- os.makedirs(db_name, exist_ok=True) # This directory is used to store persistent files from Chromadb
107
-
108
- inputs()
109
- st.session_state.llm = load_model(os.getenv("OPENAI_API_KEY"))
110
- load_embedding()
111
-
112
- if "messages" not in st.session_state:
113
- st.session_state.messages = []
114
-
115
- if st.session_state.messages:
116
- st.chat_message('human').write(st.session_state.messages[-2])
117
- st.chat_message('ai').write(st.session_state.messages[-1])
118
-
119
- query = st.chat_input("Masukkan Prompt")
120
- if query:
121
- st.chat_message("human").write(query)
122
- response = generate(query)
123
- st.chat_message("ai").write(response)
 
1
+ __import__('pysqlite3')
2
+ import sys
3
+ sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
4
+
5
+ import os
6
+ import time
7
+ from uuid import uuid4
8
+
9
+ from langchain_openai import ChatOpenAI
10
+ from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
11
+
12
+ from langchain.chains import create_retrieval_chain
13
+ from langchain_core.tools import tool
14
+
15
+ from utils.preprocess import load_data, split_data, upsert_chromadb
16
+ from utils.prebuilt_chain import history_aware_retriever, documents_retriever
17
+
18
+ import streamlit as st
19
+
20
+ db_name = "chroma" # default name for Chromadb
21
+
22
+ st.set_page_config(page_title="RAG Demo App")
23
+ st.title("Demo Retrieval Augmented Generation With LanghChain & Chroma")
24
+
25
+ @st.cache_resource
26
+ def load_model(api_key):
27
+ """cached llm and embedding model"""
28
+ if st.session_state.provider == "OpenAI":
29
+ return ChatOpenAI(model="gpt-4o-mini", temperature=0.3, api_key=api_key)
30
+ elif st.session_state.provider == "Groq":
31
+ return ChatOpenAI(model="qwen-2.5-32b", temperature=0.3, api_key=api_key, base_url="https://api.groq.com/openai/v1")
32
+
33
+ @st.cache_resource
34
+ def load_embedding():
35
+ st.session_state.embedding = FastEmbedEmbeddings(model_name="jinaai/jina-embeddings-v2-base-de",
36
+ batch_size=64)
37
+
38
+ def inputs():
39
+ """Input fields for user interaction"""
40
+ with st.sidebar:
41
+ st.session_state.provider = st.radio("Pilih model LLM", ["OpenAI", "Groq"])
42
+
43
+ st.session_state.api_key = st.text_input("Masukkan API Key", type="password")
44
+ os.environ["OPENAI_API_KEY"] = st.session_state.api_key
45
+
46
+ st.session_state.chroma_collection_name = st.text_input("Chroma Collection Name")
47
+
48
+ st.session_state.source_docs = st.file_uploader("Unggah file PDF", type=["pdf"], accept_multiple_files=True)
49
+ st.button("Proses Dokumen", on_click=process_data)
50
+
51
+ def process_data():
52
+ """Main function to process data"""
53
+ if not st.session_state.api_key or not st.session_state.chroma_collection_name or not st.session_state.source_docs:
54
+ st.error("Tolong masukan API key, Chroma collection name, dan dokumen yang diperlukan!!")
55
+ else:
56
+ with st.spinner("📚 Memproses dokumen..."):
57
+ loaded_docs = load_data(st.session_state.source_docs)
58
+ splitted_docs = split_data(loaded_docs)
59
+
60
+ idx = [str(uuid4()) for _ in range(len(splitted_docs))]
61
+
62
+ st.session_state.vector_store = upsert_chromadb(splitted_docs,
63
+ st.session_state.embedding,
64
+ idx,
65
+ st.session_state.chroma_collection_name,
66
+ db_name)
67
+ msg = st.empty()
68
+ msg.success("Dokumen berhasil diproses!")
69
+ time.sleep(3)
70
+ msg.empty()
71
+
72
+ # Main retriever
73
+ @tool(response_format="content_and_artifact")
74
+ def retrieve(query: str):
75
+ """Retrieve information related to a query.
76
+
77
+ Args:
78
+ query: The user's query.
79
+ """
80
+ retrieved_docs = st.session_state.vector_store.similarity_search(query, k=6)
81
+ keys = ["author", "creator", "page", "source", "start_index", "total_pages"]
82
+ serialized = "\n\n".join(
83
+ (f"Source: {[{key: doc.metadata.get(key)} for key in keys]}\n" f"Content: {doc.page_content}")
84
+ for doc in retrieved_docs
85
+ )
86
+ return serialized, retrieved_docs
87
+
88
+ def generate(query):
89
+ """Generate a response to the user's query."""
90
+ # Dummy retriever.
91
+ retriever = st.session_state.vector_store.as_retriever(search_kwargs={"k" : 1})
92
+
93
+ # Create a RAG chain using the history-aware retriever and the document-retriever.
94
+ history_retriever = history_aware_retriever(st.session_state.llm, retriever)
95
+ question_answer_chain = documents_retriever(st.session_state.llm)
96
+
97
+ rag_chain = create_retrieval_chain(history_retriever, question_answer_chain)
98
+
99
+ # Usage:
100
+ response = rag_chain.invoke({"input": query, "chat_history" : st.session_state.messages, "context" : retrieve.invoke(query)})
101
+ st.session_state.messages.append(query)
102
+ st.session_state.messages.append(response["answer"])
103
+ return response["answer"]
104
+
105
+ if __name__ == "__main__":
106
+ os.makedirs(db_name, exist_ok=True) # This directory is used to store persistent files from Chromadb
107
+
108
+ inputs()
109
+ st.session_state.llm = load_model(os.getenv("OPENAI_API_KEY"))
110
+ load_embedding()
111
+
112
+ if "messages" not in st.session_state:
113
+ st.session_state.messages = []
114
+
115
+ if st.session_state.messages:
116
+ st.chat_message('human').write(st.session_state.messages[-2])
117
+ st.chat_message('ai').write(st.session_state.messages[-1])
118
+
119
+ query = st.chat_input("Masukkan Prompt")
120
+ if query:
121
+ st.chat_message("human").write(query)
122
+ response = generate(query)
123
+ st.chat_message("ai").write(response)