Spaces:
Running
Running
import os | |
import streamlit as st | |
import torch | |
from langchain.chains import LLMChain | |
from langchain.prompts import ChatPromptTemplate | |
from langchain_huggingface import HuggingFaceEndpoint | |
def _escape_template_braces(text: str) -> str:
    """Double ``{``/``}`` so a prompt-template parser keeps them as literal text."""
    return text.replace("{", "{{").replace("}", "}}")


def create_prompt(name1: str, name2: str, persona_style: str):
    """Create a prompt that instructs the model to produce all 15 messages at once.

    The user-supplied values are pasted directly into the template text, and
    ``ChatPromptTemplate.from_template`` then parses that text as an f-string
    style template. Any literal ``{`` or ``}`` typed by the user would
    otherwise be read as template syntax. That would either fail to parse or
    create an input variable the chain is never given. The values are
    therefore brace-escaped first.

    Args:
        name1: Name of the speaker who sends message 1 (and message 15).
        name2: Name of the second speaker.
        persona_style: Free-text description of the desired conversational tone.

    Returns:
        A ``ChatPromptTemplate`` whose text contains no template variables.
    """
    name1 = _escape_template_braces(name1)
    name2 = _escape_template_braces(name2)
    persona_style = _escape_template_braces(persona_style)
    prompt_template_str = f"""
You are to simulate a conversation of exactly 15 messages total between two people: {name1} and {name2}.
The conversation should reflect the style: {persona_style}.
{name1} speaks first (message 1), {name2} responds (message 2), then {name1} (message 3), and so on, alternating until 15 messages are complete.
Rules:
- Each message should be written as:
{name1}: <message> or {name2}: <message>
- Each message should be 1-2 short sentences, friendly, and natural.
- Keep it casual, can ask questions, share opinions.
- Use emojis sparingly if it fits the persona (no more than 1-2 per message).
- Do not repeat the same line over and over.
- The conversation must flow logically and naturally.
- After producing exactly 15 messages (the 15th message by {name1}), stop. Do not add anything else.
- Do not continue the conversation beyond 15 messages.
Produce all 15 messages now:
"""
    return ChatPromptTemplate.from_template(prompt_template_str)
def summarize_conversation(chain: LLMChain, conversation: str, name1: str, name2: str):
    """Summarize the completed conversation.

    Only the LLM of ``chain`` is reused, not its prompt. In this app the chain
    passed in is the conversation-generation chain. Its prompt has no
    ``{input}`` variable, so ``LLMChain`` would silently drop an ``input=``
    argument and simply generate another conversation. Instead, a dedicated
    single-variable ``{input}`` prompt is wrapped around ``chain.llm``.
    Because the summary request is passed as a variable *value*, braces
    inside the conversation text are not parsed as template syntax.

    Args:
        chain: An ``LLMChain`` whose ``llm`` is used to produce the summary.
        conversation: The full generated conversation text.
        name1: Name of the first speaker (used in the request wording).
        name2: Name of the second speaker (used in the request wording).

    Returns:
        The model's title/summary text, stripped of surrounding whitespace.
        If anything fails, returns a fixed placeholder string after reporting
        the error via ``st.error`` and ``print``.
    """
    st.write("**Summarizing the conversation...**")
    print("Summarizing the conversation...")
    summary_prompt = f"""
Below is a completed conversation between {name1} and {name2}:
{conversation}
Use the conversation above and write a short Title and a summary of above conversation. The summary should be in paragraph which highlights what was the conversation about.
"""
    try:
        summary_chain = LLMChain(
            llm=chain.llm,
            prompt=ChatPromptTemplate.from_template("{input}"),
        )
        response = summary_chain.run(input=summary_prompt)
        return response.strip()
    except Exception as e:
        # UI boundary: report the failure and fall back to a placeholder
        # rather than aborting the Streamlit run.
        st.error(f"Error summarizing conversation: {e}")
        print(f"Error summarizing conversation: {e}")
        return "Title: No Title\nSummary: No summary available due to error."
def main():
    """Render the Streamlit UI and run the conversation simulation on demand.

    When the start button is pressed, this function:

    1. Builds a ``HuggingFaceEndpoint`` for the selected model's Inference API URL.
    2. Generates the whole 15-message conversation in a single chain call.
    3. Displays the conversation.
    4. Displays a title/summary obtained from ``summarize_conversation``.

    Failures are reported with ``st.error`` (and printed to the server log)
    instead of propagating.
    """
    st.title("LLM Conversation Simulation")
    model_names = [
        "meta-llama/Llama-3.3-70B-Instruct",
        "meta-llama/Llama-3.1-405B-Instruct",
        "lmsys/vicuna-13b-v1.5",
    ]
    selected_model = st.selectbox("Select a model:", model_names)
    name1 = st.text_input("Enter the first user's name:", value="Alice")
    name2 = st.text_input("Enter the second user's name:", value="Bob")
    persona_style = st.text_area("Enter the persona style characteristics:",
                                 value="friendly, curious, and a bit sarcastic")

    if not st.button("Start Conversation Simulation"):
        return

    # Stray whitespace would leak into the prompt and the "name:" labels;
    # an empty name would produce a nonsensical prompt.
    name1, name2 = name1.strip(), name2.strip()
    if not name1 or not name2:
        st.error("Please enter both user names before starting the simulation.")
        return

    st.write("**Loading model...**")
    print("Loading model...")
    with st.spinner("Starting simulation..."):
        endpoint_url = f"https://api-inference.huggingface.co/models/{selected_model}"
        try:
            llm = HuggingFaceEndpoint(
                endpoint_url=endpoint_url,
                huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
                task="text-generation",
                temperature=0.7,
                max_new_tokens=512,
            )
            st.write("**Model loaded successfully!**")
            print("Model loaded successfully!")
        except Exception as e:
            st.error(f"Error initializing HuggingFaceEndpoint: {e}")
            print(f"Error initializing HuggingFaceEndpoint: {e}")
            return

        prompt = create_prompt(name1, name2, persona_style)
        chain = LLMChain(llm=llm, prompt=prompt)

        st.write("**Generating the full 15-message conversation...**")
        print("Generating the full 15-message conversation...")
        try:
            # Generate all 15 messages in one go. NOTE: the prompt built by
            # create_prompt has no input variables, so these kwargs are not
            # substituted into it; they are kept because Chain.run rejects a
            # call with no arguments at all.
            conversation = chain.run(chat_history="", input="Produce the full conversation now.")
        except Exception as e:
            st.error(f"Error generating conversation: {e}")
            print(f"Error generating conversation: {e}")
            return

        conversation = conversation.strip()
        if not conversation:
            # Nothing to show or summarize; summarizing an empty string would
            # only produce a meaningless model response.
            st.warning("The model returned an empty conversation.")
            print("The model returned an empty conversation.")
            return

        # Print and display the conversation
        st.subheader("Final Conversation:")
        st.text(conversation)
        print("Conversation Generation Complete.\n")
        print("Full Conversation:\n", conversation)

        # Summarize the conversation (summarize_conversation reports its own errors)
        st.subheader("Summary and Title:")
        summary = summarize_conversation(chain, conversation, name1, name2)
        st.write(summary)
        print("Summary:\n", summary)
if __name__ == "__main__":
    # Entry point when the file is executed as a script (e.g. via `streamlit run`).
    main()