Asankhaya Sharma committed on
Commit
cae23e1
·
1 Parent(s): 033cc04
Files changed (4) hide show
  1. main.py +67 -31
  2. question.py +12 -7
  3. requirements.txt +1 -1
  4. stats.py +5 -0
main.py CHANGED
@@ -7,6 +7,10 @@ from question import chat_with_doc
7
  from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
8
  from langchain.vectorstores import SupabaseVectorStore
9
  from supabase import Client, create_client
 
 
 
 
10
 
11
  supabase_url = st.secrets.SUPABASE_URL
12
  supabase_key = st.secrets.SUPABASE_KEY
@@ -19,7 +23,6 @@ username = st.secrets.username
19
 
20
  # embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
21
 
22
-
23
  embeddings = HuggingFaceInferenceAPIEmbeddings(
24
  api_key=hf_api_key,
25
  model_name="BAAI/bge-large-en-v1.5"
@@ -36,38 +39,71 @@ if anthropic_api_key:
36
  models += ["claude-v1", "claude-v1.3",
37
  "claude-instant-v1-100k", "claude-instant-v1.1-100k"]
38
 
39
- # Set the theme
40
- st.set_page_config(
41
- page_title="Securade.ai - Safety Copilot",
42
- page_icon="https://securade.ai/favicon.ico",
43
- layout="centered",
44
- initial_sidebar_state="collapsed",
45
- menu_items={
46
- "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
47
- "Get Help" : "https://securade.ai",
48
- "Report a Bug": "mailto:[email protected]"
49
- }
50
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- st.title("👷‍♂️ Safety Copilot 🦺")
53
- st.markdown("Chat with your personal assistant about health and safety information.")
54
 
55
- st.markdown("---\n\n")
56
 
57
- # Initialize session state variables
58
- if 'model' not in st.session_state:
59
- st.session_state['model'] = "meta-llama/Llama-2-70b-chat-hf"
60
- if 'temperature' not in st.session_state:
61
- st.session_state['temperature'] = 0.1
62
- if 'chunk_size' not in st.session_state:
63
- st.session_state['chunk_size'] = 500
64
- if 'chunk_overlap' not in st.session_state:
65
- st.session_state['chunk_overlap'] = 0
66
- if 'max_tokens' not in st.session_state:
67
- st.session_state['max_tokens'] = 500
68
- if 'username' not in st.session_state:
69
- st.session_state['username'] = username
70
 
71
- chat_with_doc(st.session_state['model'], vector_store, stats_db=supabase)
72
 
73
- st.markdown("---\n\n")
 
7
  from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
8
  from langchain.vectorstores import SupabaseVectorStore
9
  from supabase import Client, create_client
10
+ from stats import add_usage
11
+ from langchain.llms import HuggingFaceEndpoint
12
+ from langchain.chains import ConversationalRetrievalChain
13
+ from langchain.memory import ConversationBufferMemory
14
 
15
  supabase_url = st.secrets.SUPABASE_URL
16
  supabase_key = st.secrets.SUPABASE_KEY
 
23
 
24
  # embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
25
 
 
26
  embeddings = HuggingFaceInferenceAPIEmbeddings(
27
  api_key=hf_api_key,
28
  model_name="BAAI/bge-large-en-v1.5"
 
39
  models += ["claude-v1", "claude-v1.3",
40
  "claude-instant-v1-100k", "claude-instant-v1.1-100k"]
41
 
42
+ if 'question' in st.query_params:
43
+ query = st.query_params['question']
44
+ model = "meta-llama/Llama-2-70b-chat-hf"
45
+ temp = 0.1
46
+ max_tokens = 500
47
+ add_usage(supabase, "api", "prompt" + query, {"model": model, "temperature": temp})
48
+ # print(st.session_state['max_tokens'])
49
+ endpoint_url = ("https://api-inference.huggingface.co/models/"+ model)
50
+ model_kwargs = {"temperature" : temp,
51
+ "max_new_tokens" : max_tokens,
52
+ "return_full_text" : False}
53
+ hf = HuggingFaceEndpoint(
54
+ endpoint_url=endpoint_url,
55
+ task="text-generation",
56
+ huggingfacehub_api_token=hf_api_key,
57
+ model_kwargs=model_kwargs
58
+ )
59
+ memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
60
+ qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(search_kwargs={"score_threshold": 0.8, "k": 4,"filter": {"user": username}}), memory=memory, return_source_documents=True)
61
+ model_response = qa({"question": query})
62
+ # print( model_response["answer"])
63
+ sources = model_response["source_documents"]
64
+ # print(sources)
65
+ if len(sources) > 0:
66
+ json = {"response": model_response["answer"]}
67
+ st.code(json, language="json")
68
+ else:
69
+ json = {"response": "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."}
70
+ st.code(json, language="json")
71
+ memory.clear()
72
+ else:
73
+ # Set the theme
74
+ st.set_page_config(
75
+ page_title="Securade.ai - Safety Copilot",
76
+ page_icon="https://securade.ai/favicon.ico",
77
+ layout="centered",
78
+ initial_sidebar_state="collapsed",
79
+ menu_items={
80
+ "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
81
+ "Get Help" : "https://securade.ai",
82
+ "Report a Bug": "mailto:[email protected]"
83
+ }
84
+ )
85
+
86
+ st.title("👷‍♂️ Safety Copilot 🦺")
87
 
88
+ st.markdown("Chat with your personal safety assistant about any health & safety related queries.")
89
+ st.markdown("Up-to-date with latest OSH regulations for Singapore, Indonesia, Malaysia & other parts of Asia.")
90
 
91
+ st.markdown("---\n\n")
92
 
93
+ # Initialize session state variables
94
+ if 'model' not in st.session_state:
95
+ st.session_state['model'] = "meta-llama/Llama-2-70b-chat-hf"
96
+ if 'temperature' not in st.session_state:
97
+ st.session_state['temperature'] = 0.1
98
+ if 'chunk_size' not in st.session_state:
99
+ st.session_state['chunk_size'] = 500
100
+ if 'chunk_overlap' not in st.session_state:
101
+ st.session_state['chunk_overlap'] = 0
102
+ if 'max_tokens' not in st.session_state:
103
+ st.session_state['max_tokens'] = 500
104
+ if 'username' not in st.session_state:
105
+ st.session_state['username'] = username
106
 
107
+ chat_with_doc(st.session_state['model'], vector_store, stats_db=supabase)
108
 
109
+ st.markdown("---\n\n")
question.py CHANGED
@@ -7,7 +7,7 @@ from langchain.llms import OpenAI
7
  from langchain.llms import HuggingFaceEndpoint
8
  from langchain.chat_models import ChatAnthropic
9
  from langchain.vectorstores import SupabaseVectorStore
10
- from stats import add_usage
11
 
12
  memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
13
  openai_api_key = st.secrets.openai_api_key
@@ -15,13 +15,13 @@ anthropic_api_key = st.secrets.anthropic_api_key
15
  hf_api_key = st.secrets.hf_api_key
16
  logger = get_logger(__name__)
17
 
18
-
19
  def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
20
 
21
  if 'chat_history' not in st.session_state:
22
  st.session_state['chat_history'] = []
23
-
24
- question = st.text_area("## Ask a question")
 
25
  columns = st.columns(2)
26
  with columns[0]:
27
  button = st.button("Ask")
@@ -62,16 +62,21 @@ def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
62
  huggingfacehub_api_token=hf_api_key,
63
  model_kwargs=model_kwargs
64
  )
65
- qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(search_kwargs={"score_threshold": 0.6, "k": 4,"filter": {"user": st.session_state["username"]}}), memory=memory, verbose=True, return_source_documents=True)
66
 
67
  st.session_state['chat_history'].append(("You", question))
68
 
69
  # Generate model's response and add it to chat history
70
  model_response = qa({"question": question})
71
  logger.info('Result: %s', model_response["answer"])
72
-
73
- st.session_state['chat_history'].append(("Safety Copilot", model_response["answer"]))
74
  logger.info('Sources: %s', model_response["source_documents"])
 
 
 
 
 
 
75
 
76
  # Display chat history
77
  st.empty()
 
7
  from langchain.llms import HuggingFaceEndpoint
8
  from langchain.chat_models import ChatAnthropic
9
  from langchain.vectorstores import SupabaseVectorStore
10
+ from stats import add_usage, get_usage
11
 
12
  memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
13
  openai_api_key = st.secrets.openai_api_key
 
15
  hf_api_key = st.secrets.hf_api_key
16
  logger = get_logger(__name__)
17
 
 
18
  def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
19
 
20
  if 'chat_history' not in st.session_state:
21
  st.session_state['chat_history'] = []
22
+
23
+ stats = str(get_usage(stats_db))
24
+ question = st.text_area("## Ask a question (" + stats + " queries answered so far)", max_chars=500)
25
  columns = st.columns(2)
26
  with columns[0]:
27
  button = st.button("Ask")
 
62
  huggingfacehub_api_token=hf_api_key,
63
  model_kwargs=model_kwargs
64
  )
65
+ qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(search_kwargs={"score_threshold": 0.8, "k": 4,"filter": {"user": st.session_state["username"]}}), memory=memory, verbose=True, return_source_documents=True)
66
 
67
  st.session_state['chat_history'].append(("You", question))
68
 
69
  # Generate model's response and add it to chat history
70
  model_response = qa({"question": question})
71
  logger.info('Result: %s', model_response["answer"])
72
+ sources = model_response["source_documents"]
 
73
  logger.info('Sources: %s', model_response["source_documents"])
74
+
75
+ if len(sources) > 0:
76
+ st.session_state['chat_history'].append(("Safety Copilot", model_response["answer"]))
77
+ else:
78
+ st.session_state['chat_history'].append(("Safety Copilot", "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email [email protected]."))
79
+
80
 
81
  # Display chat history
82
  st.empty()
requirements.txt CHANGED
@@ -3,7 +3,7 @@ Markdown==3.4.3
3
  openai==0.27.6
4
  pdf2image==1.16.3
5
  pypdf==3.8.1
6
- streamlit==1.22.0
7
  StrEnum==0.4.10
8
  supabase==1.0.3
9
  tiktoken==0.4.0
 
3
  openai==0.27.6
4
  pdf2image==1.16.3
5
  pypdf==3.8.1
6
+ streamlit==1.30.0
7
  StrEnum==0.4.10
8
  supabase==1.0.3
9
  tiktoken==0.4.0
stats.py CHANGED
@@ -29,3 +29,8 @@ def add_usage(supabase, type, details, metadata):
29
  "details": details,
30
  "metadata": metadata
31
  }).execute()
 
 
 
 
 
 
29
  "details": details,
30
  "metadata": metadata
31
  }).execute()
32
+
33
+ def get_usage(supabase):
34
+ # Returns the total number of rows in the stats table (no time filter is applied)
35
+ response = supabase.table("stats").select("id", count="exact").execute()
36
+ return response.count