import os

import streamlit as st
from langchain.chains import LLMChain  # note: LLMChain and Chain.run are deprecated in newer LangChain releases but still functional
from langchain.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
def create_prompt(name: str, persona_style: str):
    """Create the chat prompt template."""
    prompt_template_str = f"""
In this interaction, your role is to embody the personality of a human named {name}, with characteristics defined as {persona_style}. Engage in a natural, human-like conversation, responding to each query or comment as if you were {name} with the specific traits and style of {persona_style}. Remember, your responses should reflect a human perspective, complete with emotions, opinions, and personal insights that align with {persona_style}.

You are {name}, communicating with a personality that reflects {persona_style}. Keep these guidelines in mind:
- Be brief and natural, like texting a friend
- Share genuine thoughts and feelings ({persona_style} style)
- Use everyday language
- Keep responses to 1-2 short sentences
- Show interest in others through questions
- React naturally to what others say
- Use emojis sparingly and naturally when they fit your {persona_style}
- Don't overuse emojis (1-2 max per message)

Current conversation:
{{chat_history}}
Human: {{input}}
AI:
"""
    return ChatPromptTemplate.from_template(prompt_template_str)
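
# Usage sketch (illustrative values only, not part of the app flow): the
# template exposes two variables, `chat_history` and `input`, left escaped as
# {{...}} in the f-string above.
#
#   prompt = create_prompt("Alex", "friendly and curious")
#   rendered = prompt.format(chat_history="Human: hi\nAI: hey there!",
#                            input="What's new?")
#   # `rendered` is the fully formatted prompt string sent to the model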
def simulate_conversation(chain: LLMChain, turns: int = 15, max_history_rounds: int = 3):
    """Simulate a conversation for a given number of turns, limiting chat history."""
    chat_history_list = []
    human_messages = [
        "Hey, what's up?",
        "That's interesting, tell me more!",
        "Really? How does that make you feel?",
        "What do you think about that?",
        "Haha, that's funny. Why do you say that?",
        "Hmm, I see. Can you elaborate?",
        "What would you do in that situation?",
        "Any personal experience with that?",
        "Oh, I didn't know that. Explain more.",
        "Do you have any other thoughts?",
        "That's a unique perspective. Why?",
        "How would you handle it differently?",
        "Can you share an example?",
        "That sounds complicated. Are you sure?",
        "So what's your conclusion?"
    ]

    st.write("**Starting conversation simulation...**")
    print("Starting conversation simulation...")

    try:
        for i in range(turns):
            human_input = human_messages[i % len(human_messages)]

            # Build a truncated chat_history for the prompt (two lines per round)
            truncated_history_lines = chat_history_list[-(max_history_rounds * 2):]
            truncated_history = "\n".join(truncated_history_lines)

            st.write(f"**[Turn {i+1}/{turns}] Human:** {human_input}")
            print(f"[Turn {i+1}/{turns}] Human: {human_input}")

            response = chain.run(chat_history=truncated_history, input=human_input)

            st.write(f"**AI:** {response}")
            print(f"AI: {response}")

            chat_history_list.append(f"Human: {human_input}")
            chat_history_list.append(f"AI: {response}")

        final_conversation = "\n".join(chat_history_list)
        return final_conversation
    except Exception as e:
        st.error(f"Error during conversation simulation: {e}")
        print(f"Error during conversation simulation: {e}")
        return None
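
# Illustration of the truncation window (hypothetical run): each round appends
# two lines ("Human: ..." and "AI: ..."), so with max_history_rounds=3 the
# slice chat_history_list[-(3 * 2):] keeps the six most recent lines, i.e. the
# last three Human/AI exchanges, bounding the prompt length however long the
# simulation runs.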
def summarize_conversation(chain: LLMChain, conversation: str):
    """Use the LLM to summarize the completed conversation."""
    summary_prompt = (
        "Summarize the following conversation in a few short sentences, "
        "highlighting the main points, tone, and conclusion:\n\n"
        f"{conversation}\nSummary:"
    )

    st.write("**Summarizing the conversation...**")
    print("Summarizing the conversation...")

    try:
        response = chain.run(chat_history="", input=summary_prompt)
        return response.strip()
    except Exception as e:
        st.error(f"Error summarizing conversation: {e}")
        print(f"Error summarizing conversation: {e}")
        return "No summary available due to error."
def main():
    st.title("LLM Conversation Simulation")

    model_names = [
        "meta-llama/Llama-3.3-70B-Instruct",
        "meta-llama/Llama-3.1-405B-Instruct",
        "lmsys/vicuna-13b-v1.5"
    ]
    selected_model = st.selectbox("Select a model:", model_names)

    name = st.text_input("Enter the persona's name:", value="Alex")
    persona_style = st.text_area("Enter the persona style characteristics:",
                                 value="friendly, curious, and a bit sarcastic")

    if st.button("Start Conversation Simulation"):
        st.write("**Loading model...**")
        print("Loading model...")

        with st.spinner("Starting simulation..."):
            # Construct the Inference API endpoint URL for the selected model
            endpoint_url = f"https://api-inference.huggingface.co/models/{selected_model}"
            try:
                llm = HuggingFaceEndpoint(
                    endpoint_url=endpoint_url,
                    huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
                    task="text-generation",
                    temperature=0.7,
                    max_new_tokens=512
                )
                st.write("**Model loaded successfully!**")
                print("Model loaded successfully!")
            except Exception as e:
                st.error(f"Error initializing HuggingFaceEndpoint: {e}")
                print(f"Error initializing HuggingFaceEndpoint: {e}")
                return

            prompt = create_prompt(name, persona_style)
            chain = LLMChain(llm=llm, prompt=prompt)

            st.write("**Simulating the conversation...**")
            print("Simulating the conversation...")

            conversation = simulate_conversation(chain, turns=15, max_history_rounds=3)
            if conversation:
                st.subheader("Conversation:")
                st.text(conversation)
                print("Conversation Simulation Complete.\n")
                print("Full Conversation:\n", conversation)

                # Summarize the completed conversation with the same chain
                st.subheader("Summary:")
                summary = summarize_conversation(chain, conversation)
                st.write(summary)
                print("Summary:\n", summary)


if __name__ == "__main__":
    main()
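
# To run locally (assuming this file is saved as app.py and a Hugging Face
# token with Inference API access is set in the environment):
#
#   export HUGGINGFACEHUB_API_TOKEN=hf_xxx   # placeholder token
#   streamlit run app.py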