RakeshUtekar commited on
Commit
887b1f9
·
verified ·
1 Parent(s): 38b6ee6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import os
2
  import streamlit as st
3
  import torch
4
- from langchain.chains import LLMChain
5
  from langchain.prompts import ChatPromptTemplate
6
- from langchain_community.llms import HuggingFaceHub
 
7
 
8
  def create_prompt(name: str, persona_style: str):
9
  """Create the chat prompt template."""
@@ -32,7 +32,7 @@ def create_prompt(name: str, persona_style: str):
32
  """
33
  return ChatPromptTemplate.from_template(prompt_template_str)
34
 
35
- def simulate_conversation(chain: LLMChain, turns: int = 15, max_history_rounds=3):
36
  """Simulate a conversation for a given number of turns, limiting chat history."""
37
  chat_history_list = []
38
  human_messages = [
@@ -57,12 +57,11 @@ def simulate_conversation(chain: LLMChain, turns: int = 15, max_history_rounds=3
57
  for i in range(turns):
58
  human_input = human_messages[i % len(human_messages)]
59
 
60
- # Build truncated chat_history for prompt
61
- # Keep only the last max_history_rounds * 2 lines (Human + AI pairs)
62
  truncated_history_lines = chat_history_list[-(max_history_rounds*2):]
63
  truncated_history = "\n".join(truncated_history_lines)
64
 
65
- response = chain.run(chat_history=truncated_history, input=human_input)
66
  # Update chat history
67
  chat_history_list.append(f"Human: {human_input}")
68
  chat_history_list.append(f"AI: {response}")
@@ -73,11 +72,11 @@ def simulate_conversation(chain: LLMChain, turns: int = 15, max_history_rounds=3
73
  st.error(f"Error during conversation simulation: {e}")
74
  return None
75
 
76
- def summarize_conversation(chain: LLMChain, conversation: str):
77
  """Use the LLM to summarize the completed conversation."""
78
  summary_prompt = f"Summarize the following conversation in a few short sentences highlighting the main points, tone, and conclusion:\n\n{conversation}\nSummary:"
79
  try:
80
- response = chain.run(chat_history="", input=summary_prompt)
81
  return response.strip()
82
  except Exception as e:
83
  st.error(f"Error summarizing conversation: {e}")
@@ -86,7 +85,6 @@ def summarize_conversation(chain: LLMChain, conversation: str):
86
  def main():
87
  st.title("LLM Conversation Simulation")
88
 
89
- # Model selection
90
  model_names = [
91
  "meta-llama/Llama-3.3-70B-Instruct",
92
  "meta-llama/Llama-3.1-405B-Instruct",
@@ -94,38 +92,44 @@ def main():
94
  ]
95
  selected_model = st.selectbox("Select a model:", model_names)
96
 
97
- # Persona Inputs
98
  name = st.text_input("Enter the persona's name:", value="Alex")
99
  persona_style = st.text_area("Enter the persona style characteristics:",
100
  value="friendly, curious, and a bit sarcastic")
101
 
102
  if st.button("Start Conversation Simulation"):
103
  with st.spinner("Starting simulation..."):
 
 
 
 
 
 
 
 
 
104
  try:
105
- # Use HuggingFaceHub as LLM
106
- # Make sure you have a valid HUGGINGFACEHUB_API_TOKEN set
107
- llm = HuggingFaceHub(
108
- repo_id=selected_model,
109
  model_kwargs={
110
  "temperature": 0.7,
111
  "max_new_tokens": 512
112
  }
113
  )
114
  except Exception as e:
115
- st.error(f"Error initializing model from Hugging Face Hub: {e}")
116
  return
117
 
118
- # Create our prompt template chain
119
  prompt = create_prompt(name, persona_style)
120
- chain = LLMChain(llm=llm, prompt=prompt)
 
121
 
122
- # Simulate conversation
123
  conversation = simulate_conversation(chain, turns=15, max_history_rounds=3)
124
  if conversation:
125
  st.subheader("Conversation:")
126
  st.text(conversation)
127
 
128
- # Summarize conversation
129
  st.subheader("Summary:")
130
  summary = summarize_conversation(chain, conversation)
131
  st.write(summary)
 
1
  import os
2
  import streamlit as st
3
  import torch
 
4
  from langchain.prompts import ChatPromptTemplate
5
+ from langchain.schema.runnable import RunnableSequence
6
+ from langchain_huggingface import HuggingFaceEndpoint
7
 
8
  def create_prompt(name: str, persona_style: str):
9
  """Create the chat prompt template."""
 
32
  """
33
  return ChatPromptTemplate.from_template(prompt_template_str)
34
 
35
+ def simulate_conversation(chain: RunnableSequence, turns: int = 15, max_history_rounds=3):
36
  """Simulate a conversation for a given number of turns, limiting chat history."""
37
  chat_history_list = []
38
  human_messages = [
 
57
  for i in range(turns):
58
  human_input = human_messages[i % len(human_messages)]
59
 
60
+ # Keep only last max_history_rounds * 2 lines
 
61
  truncated_history_lines = chat_history_list[-(max_history_rounds*2):]
62
  truncated_history = "\n".join(truncated_history_lines)
63
 
64
+ response = chain.invoke({"chat_history": truncated_history, "input": human_input})
65
  # Update chat history
66
  chat_history_list.append(f"Human: {human_input}")
67
  chat_history_list.append(f"AI: {response}")
 
72
  st.error(f"Error during conversation simulation: {e}")
73
  return None
74
 
75
+ def summarize_conversation(chain: RunnableSequence, conversation: str):
76
  """Use the LLM to summarize the completed conversation."""
77
  summary_prompt = f"Summarize the following conversation in a few short sentences highlighting the main points, tone, and conclusion:\n\n{conversation}\nSummary:"
78
  try:
79
+ response = chain.invoke({"chat_history": "", "input": summary_prompt})
80
  return response.strip()
81
  except Exception as e:
82
  st.error(f"Error summarizing conversation: {e}")
 
85
  def main():
86
  st.title("LLM Conversation Simulation")
87
 
 
88
  model_names = [
89
  "meta-llama/Llama-3.3-70B-Instruct",
90
  "meta-llama/Llama-3.1-405B-Instruct",
 
92
  ]
93
  selected_model = st.selectbox("Select a model:", model_names)
94
 
 
95
  name = st.text_input("Enter the persona's name:", value="Alex")
96
  persona_style = st.text_area("Enter the persona style characteristics:",
97
  value="friendly, curious, and a bit sarcastic")
98
 
99
  if st.button("Start Conversation Simulation"):
100
  with st.spinner("Starting simulation..."):
101
+ # Build headers with your Hugging Face token
102
+ hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
103
+ if not hf_token:
104
+ st.error("HUGGINGFACEHUB_API_TOKEN not found. Please set the token.")
105
+ return
106
+
107
+ endpoint_url = f"https://api-inference.huggingface.co/models/{selected_model}"
108
+ headers = {"Authorization": f"Bearer {hf_token}"}
109
+
110
  try:
111
+ llm = HuggingFaceEndpoint(
112
+ endpoint_url=endpoint_url,
113
+ task="text-generation",
114
+ headers=headers,
115
  model_kwargs={
116
  "temperature": 0.7,
117
  "max_new_tokens": 512
118
  }
119
  )
120
  except Exception as e:
121
+ st.error(f"Error initializing HuggingFaceEndpoint: {e}")
122
  return
123
 
 
124
  prompt = create_prompt(name, persona_style)
125
+ # prompt and llm are both Runnables, chain them together
126
+ chain = RunnableSequence([prompt, llm])
127
 
 
128
  conversation = simulate_conversation(chain, turns=15, max_history_rounds=3)
129
  if conversation:
130
  st.subheader("Conversation:")
131
  st.text(conversation)
132
 
 
133
  st.subheader("Summary:")
134
  summary = summarize_conversation(chain, conversation)
135
  st.write(summary)