RakeshUtekar committed: Update app.py
app.py
CHANGED
@@ -1,8 +1,10 @@
 import os
 import streamlit as st
 import torch
+from langchain.chains import LLMChain
 from langchain.prompts import ChatPromptTemplate
-…
+
+# Use the new package for HuggingFaceEndpoint
 from langchain_huggingface import HuggingFaceEndpoint
 
 def create_prompt(name: str, persona_style: str):
@@ -32,7 +34,7 @@ def create_prompt(name: str, persona_style: str):
     """
     return ChatPromptTemplate.from_template(prompt_template_str)
 
-def simulate_conversation(chain: RunnableSequence, turns: int = 15, max_history_rounds=3):
+def simulate_conversation(chain: LLMChain, turns: int = 15, max_history_rounds=3):
     """Simulate a conversation for a given number of turns, limiting chat history."""
     chat_history_list = []
     human_messages = [
@@ -57,12 +59,11 @@ def simulate_conversation(chain: RunnableSequence, turns: int = 15, max_history_
         for i in range(turns):
             human_input = human_messages[i % len(human_messages)]
 
-            # …
+            # Build truncated chat_history for prompt
             truncated_history_lines = chat_history_list[-(max_history_rounds*2):]
             truncated_history = "\n".join(truncated_history_lines)
 
-            response = chain.…
-            # Update chat history
+            response = chain.run(chat_history=truncated_history, input=human_input)
             chat_history_list.append(f"Human: {human_input}")
             chat_history_list.append(f"AI: {response}")
 
@@ -72,11 +73,11 @@ def simulate_conversation(chain: RunnableSequence, turns: int = 15, max_history_
         st.error(f"Error during conversation simulation: {e}")
         return None
 
-def summarize_conversation(chain: RunnableSequence, conversation: str):
+def summarize_conversation(chain: LLMChain, conversation: str):
     """Use the LLM to summarize the completed conversation."""
     summary_prompt = f"Summarize the following conversation in a few short sentences highlighting the main points, tone, and conclusion:\n\n{conversation}\nSummary:"
     try:
-        response = chain.…
+        response = chain.run(chat_history="", input=summary_prompt)
         return response.strip()
     except Exception as e:
         st.error(f"Error summarizing conversation: {e}")
@@ -98,32 +99,25 @@ def main():
 
     if st.button("Start Conversation Simulation"):
         with st.spinner("Starting simulation..."):
-            # …
-            hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
-            if not hf_token:
-                st.error("HUGGINGFACEHUB_API_TOKEN not found. Please set the token.")
-                return
-
+            # Construct the endpoint URL for the selected model
            endpoint_url = f"https://api-inference.huggingface.co/models/{selected_model}"
-
-
+
             try:
+                # Use HuggingFaceEndpoint instead of HuggingFaceHub
+                # Specify temperature and max_new_tokens as top-level arguments
                 llm = HuggingFaceEndpoint(
                     endpoint_url=endpoint_url,
+                    huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
                     task="text-generation",
-                    …
-                    model_kwargs={
-                        "temperature": 0.7,
-                        "max_new_tokens": 512
-                    }
+                    temperature=0.7,
+                    max_new_tokens=512
                 )
             except Exception as e:
                 st.error(f"Error initializing HuggingFaceEndpoint: {e}")
                 return
 
             prompt = create_prompt(name, persona_style)
-
-            chain = RunnableSequence([prompt, llm])
+            chain = LLMChain(llm=llm, prompt=prompt)
 
             conversation = simulate_conversation(chain, turns=15, max_history_rounds=3)
             if conversation:
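
For context on the API change above: the commit replaces a RunnableSequence pipeline with the classic LLMChain wrapper, and moves temperature / max_new_tokens out of model_kwargs into top-level HuggingFaceEndpoint arguments. The sketch below shows how the resulting pieces fit together end to end. It is a minimal illustration under stated assumptions, not the app itself: the prompt text and the zephyr-7b-beta model name are placeholders, and it presumes HUGGINGFACEHUB_API_TOKEN is set in the environment.

import os

from langchain.chains import LLMChain
from langchain.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEndpoint

# Placeholder model; the app builds this URL from a Streamlit selection.
endpoint_url = "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"

# temperature and max_new_tokens are accepted as top-level fields on
# HuggingFaceEndpoint, which is why the commit drops the model_kwargs dict.
llm = HuggingFaceEndpoint(
    endpoint_url=endpoint_url,
    task="text-generation",
    huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
    temperature=0.7,
    max_new_tokens=512,
)

# The template must declare exactly the variables passed to chain.run()
# below, here chat_history and input, matching the calls made in
# simulate_conversation and summarize_conversation. The wording is a
# stand-in for whatever create_prompt actually renders.
prompt = ChatPromptTemplate.from_template(
    "You are a helpful assistant.\n"
    "Conversation so far:\n{chat_history}\n"
    "Human: {input}\nAI:"
)

chain = LLMChain(llm=llm, prompt=prompt)

# Keep only the last max_history_rounds human/AI pairs, i.e. two lines per
# round: the same windowing simulate_conversation applies.
max_history_rounds = 3
history = ["Human: hi", "AI: hello", "Human: how are you?", "AI: fine"]
window = "\n".join(history[-(max_history_rounds * 2):])

reply = chain.run(chat_history=window, input="What did I just ask you?")
print(reply.strip())

One design note: LLMChain.run forwards its keyword arguments to the prompt's input variables, so the chat_history="" argument in summarize_conversation is still required even though the summary prompt carries the whole conversation inside the input string; omitting a declared variable would raise a KeyError at format time.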
|