OrifjonKenjayev commited on
Commit
4cfde22
·
verified ·
1 Parent(s): 23e787f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import os
4
+ from langchain_community.vectorstores import FAISS
5
+ from langchain_together import TogetherEmbeddings
6
+ from operator import itemgetter
7
+ from langchain.memory import ConversationBufferMemory
8
+ from langchain.schema import format_document
9
+ from typing import List, Tuple
10
+
11
+ # Environment variables for API keys
12
+ TOGETHER_API_KEY = os.getenv('TOGETHER_API_KEY')
13
+
14
+ class ChatBot:
15
+ def __init__(self):
16
+ # Load the pre-created FAISS index
17
+ self.vectorstore = FAISS.load_local("faiss_index")
18
+ self.retriever = self.vectorstore.as_retriever()
19
+
20
+ # Initialize the model
21
+ self.model = Together(
22
+ model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
23
+ temperature=0.7,
24
+ max_tokens=128,
25
+ top_k=50,
26
+ together_api_key=TOGETHER_API_KEY
27
+ )
28
+
29
+ # Initialize memory
30
+ self.memory = ConversationBufferMemory(
31
+ return_messages=True,
32
+ memory_key="chat_history",
33
+ output_key="answer"
34
+ )
35
+
36
+ # Create the prompt template
37
+ self.template = """<s>[INST] Based on the following context and chat history, answer the question naturally:
38
+
39
+ Context: {context}
40
+
41
+ Chat History: {chat_history}
42
+
43
+ Question: {question} [/INST]"""
44
+
45
+ self.prompt = ChatPromptTemplate.from_template(self.template)
46
+
47
+ # Create the chain
48
+ self.chain = (
49
+ {
50
+ "context": self.retriever,
51
+ "chat_history": lambda x: self.get_chat_history(),
52
+ "question": RunnablePassthrough()
53
+ }
54
+ | self.prompt
55
+ | self.model
56
+ | StrOutputParser()
57
+ )
58
+
59
+ def get_chat_history(self) -> str:
60
+ """Format chat history for the prompt"""
61
+ messages = self.memory.load_memory_variables({})["chat_history"]
62
+ return "\n".join([f"{m.type}: {m.content}" for m in messages])
63
+
64
+ def process_response(self, response: str) -> str:
65
+ """Clean up the response"""
66
+ response = response.replace("[/INST]", "").replace("<s>", "").replace("</s>", "")
67
+ return response.strip()
68
+
69
+ def chat(self, message: str, history: List[Tuple[str, str]]) -> str:
70
+ """Process a single chat message"""
71
+ self.memory.chat_memory.add_user_message(message)
72
+ response = self.chain.invoke(message)
73
+ clean_response = self.process_response(response)
74
+ self.memory.chat_memory.add_ai_message(clean_response)
75
+ return clean_response
76
+
77
+ def reset_chat(self) -> List[Tuple[str, str]]:
78
+ """Reset the chat history"""
79
+ self.memory.clear()
80
+ return []
81
+
82
+ # Create the Gradio interface
83
+ def create_demo() -> gr.Interface:
84
+ chatbot = ChatBot()
85
+
86
+ with gr.Blocks() as demo:
87
+ gr.Markdown("""# Knowledge Base Chatbot
88
+ Ask questions about your documents and get informed responses!""")
89
+
90
+ chatbot_interface = gr.Chatbot(
91
+ height=600,
92
+ show_copy_button=True,
93
+ )
94
+
95
+ with gr.Row():
96
+ msg = gr.Textbox(
97
+ show_label=False,
98
+ placeholder="Type your message here...",
99
+ container=False
100
+ )
101
+ submit = gr.Button("Send", variant="primary")
102
+
103
+ clear = gr.Button("New Chat")
104
+
105
+ def respond(message, chat_history):
106
+ bot_message = chatbot.chat(message, chat_history)
107
+ chat_history.append((message, bot_message))
108
+ return "", chat_history
109
+
110
+ submit.click(respond, [msg, chatbot_interface], [msg, chatbot_interface])
111
+ msg.submit(respond, [msg, chatbot_interface], [msg, chatbot_interface])
112
+ clear.click(lambda: chatbot.reset_chat(), None, chatbot_interface)
113
+
114
+ return demo
115
+
116
+ demo = create_demo()
117
+
118
+ if __name__ == "__main__":
119
+ demo.launch()