Update app.py
app.py CHANGED
@@ -1,9 +1,9 @@
 import os
 import streamlit as st
 import torch
-from langchain.chains import LLMChain
 from langchain.prompts import ChatPromptTemplate
-from
+from langchain.schema.runnable import RunnableSequence
+from langchain_huggingface import HuggingFaceEndpoint

 def create_prompt(name: str, persona_style: str):
     """Create the chat prompt template."""
@@ -32,7 +32,7 @@ def create_prompt(name: str, persona_style: str):
     """
     return ChatPromptTemplate.from_template(prompt_template_str)

-def simulate_conversation(chain: LLMChain, turns: int = 15, max_history_rounds=3):
+def simulate_conversation(chain: RunnableSequence, turns: int = 15, max_history_rounds=3):
     """Simulate a conversation for a given number of turns, limiting chat history."""
     chat_history_list = []
     human_messages = [
@@ -57,12 +57,11 @@ def simulate_conversation(chain: LLMChain, turns: int = 15, max_history_rounds=3):
         for i in range(turns):
             human_input = human_messages[i % len(human_messages)]

-            #
-            # Keep only the last max_history_rounds * 2 lines (Human + AI pairs)
+            # Keep only last max_history_rounds * 2 lines
             truncated_history_lines = chat_history_list[-(max_history_rounds*2):]
             truncated_history = "\n".join(truncated_history_lines)

-            response = chain.
+            response = chain.invoke({"chat_history": truncated_history, "input": human_input})
             # Update chat history
             chat_history_list.append(f"Human: {human_input}")
             chat_history_list.append(f"AI: {response}")
@@ -73,11 +72,11 @@ def simulate_conversation(chain: LLMChain, turns: int = 15, max_history_rounds=3):
         st.error(f"Error during conversation simulation: {e}")
         return None

-def summarize_conversation(chain: LLMChain, conversation: str):
+def summarize_conversation(chain: RunnableSequence, conversation: str):
     """Use the LLM to summarize the completed conversation."""
     summary_prompt = f"Summarize the following conversation in a few short sentences highlighting the main points, tone, and conclusion:\n\n{conversation}\nSummary:"
     try:
-        response = chain.
+        response = chain.invoke({"chat_history": "", "input": summary_prompt})
         return response.strip()
     except Exception as e:
         st.error(f"Error summarizing conversation: {e}")
@@ -86,7 +85,6 @@ def summarize_conversation(chain: LLMChain, conversation: str):
 def main():
     st.title("LLM Conversation Simulation")

-    # Model selection
     model_names = [
         "meta-llama/Llama-3.3-70B-Instruct",
         "meta-llama/Llama-3.1-405B-Instruct",
@@ -94,38 +92,44 @@ def main():
     ]
     selected_model = st.selectbox("Select a model:", model_names)

-    # Persona Inputs
     name = st.text_input("Enter the persona's name:", value="Alex")
     persona_style = st.text_area("Enter the persona style characteristics:",
                                  value="friendly, curious, and a bit sarcastic")

     if st.button("Start Conversation Simulation"):
         with st.spinner("Starting simulation..."):
+            # Build headers with your Hugging Face token
+            hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+            if not hf_token:
+                st.error("HUGGINGFACEHUB_API_TOKEN not found. Please set the token.")
+                return
+
+            endpoint_url = f"https://api-inference.huggingface.co/models/{selected_model}"
+            headers = {"Authorization": f"Bearer {hf_token}"}
+
             try:
-
-
-
-
+                llm = HuggingFaceEndpoint(
+                    endpoint_url=endpoint_url,
+                    task="text-generation",
+                    headers=headers,
                     model_kwargs={
                         "temperature": 0.7,
                         "max_new_tokens": 512
                     }
                 )
             except Exception as e:
-                st.error(f"Error initializing
+                st.error(f"Error initializing HuggingFaceEndpoint: {e}")
                 return

-            # Create our prompt template chain
             prompt = create_prompt(name, persona_style)
-
+            # prompt and llm are both Runnables, chain them together
+            chain = RunnableSequence([prompt, llm])

-            # Simulate conversation
             conversation = simulate_conversation(chain, turns=15, max_history_rounds=3)
             if conversation:
                 st.subheader("Conversation:")
                 st.text(conversation)

-                # Summarize conversation
                 st.subheader("Summary:")
                 summary = summarize_conversation(chain, conversation)
                 st.write(summary)
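A few review notes on the change follow.

The history truncation in simulate_conversation keeps only the last max_history_rounds Human/AI pairs, since each turn appends two lines to chat_history_list. A quick sketch of the slice's behavior with max_history_rounds = 3:

    # Each turn appends one "Human: ..." and one "AI: ..." entry, so the
    # last 3 rounds are the last 3 * 2 = 6 list entries.
    history = ["Human: q1", "AI: a1", "Human: q2", "AI: a2",
               "Human: q3", "AI: a3", "Human: q4", "AI: a4"]
    print(history[-(3 * 2):])  # ['Human: q2', 'AI: a2', ..., 'AI: a4'] (last 6)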
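The dict passed to chain.invoke must use exactly the variable names left open in the template returned by create_prompt; the template body (lines 10-31 of app.py) falls outside these hunks, so the sketch below is a hypothetical stand-in that assumes chat_history and input are the only remaining variables, with name and persona_style interpolated when the string is built:

    from langchain.prompts import ChatPromptTemplate

    # Hypothetical stand-in for prompt_template_str; the real one is not
    # shown in this diff. Only {chat_history} and {input} stay as variables.
    prompt_template_str = (
        "You are Alex, a friendly, curious, and a bit sarcastic persona.\n"
        "Conversation so far:\n{chat_history}\n"
        "Human: {input}\nAI:"
    )
    prompt = ChatPromptTemplate.from_template(prompt_template_str)
    print(prompt.input_variables)  # ['chat_history', 'input']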
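HuggingFaceEndpoint can also authenticate through its huggingfacehub_api_token field rather than a hand-built Authorization header, and it accepts generation settings such as temperature and max_new_tokens as top-level fields; whether a headers keyword is honored varies across langchain_huggingface releases, so it is worth verifying against the installed version. A hedged alternative sketch (using repo_id instead of a hand-built endpoint_url is an assumption):

    import os
    from langchain_huggingface import HuggingFaceEndpoint

    # Sketch only; verify field names against the installed
    # langchain_huggingface release before relying on this.
    llm = HuggingFaceEndpoint(
        repo_id="meta-llama/Llama-3.3-70B-Instruct",  # or endpoint_url=...
        task="text-generation",
        huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
        temperature=0.7,
        max_new_tokens=512,
    )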
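Depending on the installed LangChain version, RunnableSequence([prompt, llm]) may raise, because RunnableSequence expects its steps as separate arguments rather than a single list; the idiomatic LCEL spelling is the pipe operator. Adding StrOutputParser keeps invoke() returning a plain string, which the response.strip() call in summarize_conversation assumes:

    from langchain_core.output_parsers import StrOutputParser

    # Equivalent chain via LCEL composition; chain.invoke(...) then
    # returns a str, matching the .strip() usage above.
    chain = prompt | llm | StrOutputParser()
    response = chain.invoke({"chat_history": "", "input": "Hello there!"})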