Libidrave commited on
Commit
b5262d1
·
verified ·
1 Parent(s): 90cb715

Up to Spaces

Browse files
Files changed (2) hide show
  1. requirements.txt +14 -0
  2. streamlitrag.py +123 -0
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain
2
+ langgraph
3
+ langchain-core
4
+ langchain-text-splitters
5
+ langchain-community
6
+ langchain-openai
7
+ langchain-chroma
8
+ openai
9
+ chromadb
10
+ python-dotenv
11
+ pandas
12
+ pymupdf
13
+ pysqlite3-binary
14
+ fastembed
streamlitrag.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __import__('pysqlite3')
2
+ import sys
3
+ sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
4
+
5
+ import os
6
+ import time
7
+ from uuid import uuid4
8
+
9
+ from langchain_openai import ChatOpenAI
10
+ from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
11
+
12
+ from langchain.chains import create_retrieval_chain
13
+ from langchain_core.tools import tool
14
+
15
+ from utils.preprocess import load_data, split_data, upsert_chromadb
16
+ from utils.prebuilt_chain import history_aware_retriever, documents_retriever
17
+
18
+ import streamlit as st
19
+
20
+ db_name = "chroma" # default name for Chromadb
21
+
22
+ st.set_page_config(page_title="RAG Demo App")
23
+ st.title("Demo Retrieval Augmented Generation With LanghChain & Chroma")
24
+
25
+ @st.cache_resource
26
+ def load_model(api_key):
27
+ """cached llm and embedding model"""
28
+ if st.session_state.provider == "OpenAI":
29
+ return ChatOpenAI(model="gpt-4o-mini", temperature=0.3, api_key=api_key)
30
+ elif st.session_state.provider == "Groq":
31
+ return ChatOpenAI(model="llama-3.1-8b-instant", temperature=0.3, api_key=api_key, base_url="https://api.groq.com/openai/v1")
32
+
33
+ @st.cache_resource
34
+ def load_embedding():
35
+ st.session_state.embedding = FastEmbedEmbeddings(model_name="jinaai/jina-embeddings-v2-base-de",
36
+ batch_size=64)
37
+
38
+ def inputs():
39
+ """Input fields for user interaction"""
40
+ with st.sidebar:
41
+ st.session_state.provider = st.radio("Pilih model LLM", ["OpenAI", "Groq"])
42
+
43
+ st.session_state.api_key = st.text_input("Masukkan API Key", type="password")
44
+ os.environ["OPENAI_API_KEY"] = st.session_state.api_key
45
+
46
+ st.session_state.chroma_collection_name = st.text_input("Chroma Collection Name")
47
+
48
+ st.session_state.source_docs = st.file_uploader("Unggah file PDF", type=["pdf"], accept_multiple_files=True)
49
+ st.button("Proses Dokumen", on_click=process_data)
50
+
51
+ def process_data():
52
+ """Main function to process data"""
53
+ if not st.session_state.api_key or not st.session_state.chroma_collection_name or not st.session_state.source_docs:
54
+ st.error("Tolong masukan API key, Chroma collection name, dan dokumen yang diperlukan!!")
55
+ else:
56
+ with st.spinner("📚 Memproses dokumen..."):
57
+ loaded_docs = load_data(st.session_state.source_docs)
58
+ splitted_docs = split_data(loaded_docs)
59
+
60
+ idx = [str(uuid4()) for _ in range(len(splitted_docs))]
61
+
62
+ st.session_state.vector_store = upsert_chromadb(splitted_docs,
63
+ st.session_state.embedding,
64
+ idx,
65
+ st.session_state.chroma_collection_name,
66
+ db_name)
67
+ msg = st.empty()
68
+ msg.success("Dokumen berhasil diproses!")
69
+ time.sleep(3)
70
+ msg.empty()
71
+
72
+ # Main retriever
73
+ @tool(response_format="content_and_artifact")
74
+ def retrieve(query: str):
75
+ """Retrieve information related to a query.
76
+
77
+ Args:
78
+ query: The user's query.
79
+ """
80
+ retrieved_docs = st.session_state.vector_store.similarity_search(query, k=6)
81
+ keys = ["author", "creator", "page", "source", "start_index", "total_pages"]
82
+ serialized = "\n\n".join(
83
+ (f"Source: {[{key: doc.metadata.get(key)} for key in keys]}\n" f"Content: {doc.page_content}")
84
+ for doc in retrieved_docs
85
+ )
86
+ return serialized, retrieved_docs
87
+
88
+ def generate(query):
89
+ """Generate a response to the user's query."""
90
+ # Dummy retriever.
91
+ retriever = st.session_state.vector_store.as_retriever(search_kwargs={"k" : 1})
92
+
93
+ # Create a RAG chain using the history-aware retriever and the document-retriever.
94
+ history_retriever = history_aware_retriever(st.session_state.llm, retriever)
95
+ question_answer_chain = documents_retriever(st.session_state.llm)
96
+
97
+ rag_chain = create_retrieval_chain(history_retriever, question_answer_chain)
98
+
99
+ # Usage:
100
+ response = rag_chain.invoke({"input": query, "chat_history" : st.session_state.messages, "context" : retrieve.invoke(query)})
101
+ st.session_state.messages.append(query)
102
+ st.session_state.messages.append(response["answer"])
103
+ return response["answer"]
104
+
105
+ if __name__ == "__main__":
106
+ os.makedirs(db_name, exist_ok=True) # This directory is used to store persistent files from Chromadb
107
+
108
+ inputs()
109
+ st.session_state.llm = load_model(os.getenv("OPENAI_API_KEY"))
110
+ load_embedding()
111
+
112
+ if "messages" not in st.session_state:
113
+ st.session_state.messages = []
114
+
115
+ if st.session_state.messages:
116
+ st.chat_message('human').write(st.session_state.messages[-2])
117
+ st.chat_message('ai').write(st.session_state.messages[-1])
118
+
119
+ query = st.chat_input("Masukkan Prompt")
120
+ if query:
121
+ st.chat_message("human").write(query)
122
+ response = generate(query)
123
+ st.chat_message("ai").write(response)