RakeshUtekar committed
Commit 215e74e · verified · 1 Parent(s): 19bd580

Update app.py

Files changed (1)
  1. app.py +16 -22
app.py CHANGED
@@ -1,8 +1,10 @@
 import os
 import streamlit as st
 import torch
+from langchain.chains import LLMChain
 from langchain.prompts import ChatPromptTemplate
-from langchain.schema.runnable import RunnableSequence
+
+# Use the new package for HuggingFaceEndpoint
 from langchain_huggingface import HuggingFaceEndpoint
 
 def create_prompt(name: str, persona_style: str):
@@ -32,7 +34,7 @@ def create_prompt(name: str, persona_style: str):
     """
     return ChatPromptTemplate.from_template(prompt_template_str)
 
-def simulate_conversation(chain: RunnableSequence, turns: int = 15, max_history_rounds=3):
+def simulate_conversation(chain: LLMChain, turns: int = 15, max_history_rounds=3):
     """Simulate a conversation for a given number of turns, limiting chat history."""
     chat_history_list = []
     human_messages = [
@@ -57,12 +59,11 @@ def simulate_conversation(chain: RunnableSequence, turns: int = 15, max_history_rounds=3):
         for i in range(turns):
             human_input = human_messages[i % len(human_messages)]
 
-            # Keep only last max_history_rounds * 2 lines
+            # Build truncated chat_history for prompt
            truncated_history_lines = chat_history_list[-(max_history_rounds*2):]
             truncated_history = "\n".join(truncated_history_lines)
 
-            response = chain.invoke({"chat_history": truncated_history, "input": human_input})
-            # Update chat history
+            response = chain.run(chat_history=truncated_history, input=human_input)
             chat_history_list.append(f"Human: {human_input}")
             chat_history_list.append(f"AI: {response}")
 
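Note on the truncation logic kept by this hunk: the slice chat_history_list[-(max_history_rounds*2):] retains only the last max_history_rounds Human/AI pairs, since each round appends two lines. A standalone sketch of just that slicing, with illustrative values that are not part of the commit:

    history = ["Human: q1", "AI: a1", "Human: q2", "AI: a2",
               "Human: q3", "AI: a3", "Human: q4", "AI: a4"]
    max_history_rounds = 3
    # Two lines per round, so the last 3 rounds are the last 6 entries.
    print(history[-(max_history_rounds * 2):])
    # -> ['Human: q2', 'AI: a2', 'Human: q3', 'AI: a3', 'Human: q4', 'AI: a4']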
@@ -72,11 +73,11 @@ def simulate_conversation(chain: RunnableSequence, turns: int = 15, max_history_rounds=3):
         st.error(f"Error during conversation simulation: {e}")
         return None
 
-def summarize_conversation(chain: RunnableSequence, conversation: str):
+def summarize_conversation(chain: LLMChain, conversation: str):
     """Use the LLM to summarize the completed conversation."""
     summary_prompt = f"Summarize the following conversation in a few short sentences highlighting the main points, tone, and conclusion:\n\n{conversation}\nSummary:"
     try:
-        response = chain.invoke({"chat_history": "", "input": summary_prompt})
+        response = chain.run(chat_history="", input=summary_prompt)
         return response.strip()
     except Exception as e:
         st.error(f"Error summarizing conversation: {e}")
@@ -98,32 +99,25 @@ def main():
 
     if st.button("Start Conversation Simulation"):
         with st.spinner("Starting simulation..."):
-            # Build headers with your Hugging Face token
-            hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
-            if not hf_token:
-                st.error("HUGGINGFACEHUB_API_TOKEN not found. Please set the token.")
-                return
-
+            # Construct the endpoint URL for the selected model
             endpoint_url = f"https://api-inference.huggingface.co/models/{selected_model}"
-            headers = {"Authorization": f"Bearer {hf_token}"}
 
             try:
+                # Use HuggingFaceEndpoint instead of HuggingFaceHub
+                # Specify temperature and max_new_tokens as top-level arguments
                 llm = HuggingFaceEndpoint(
                     endpoint_url=endpoint_url,
+                    huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
                     task="text-generation",
-                    headers=headers,
-                    model_kwargs={
-                        "temperature": 0.7,
-                        "max_new_tokens": 512
-                    }
+                    temperature=0.7,
+                    max_new_tokens=512
                 )
             except Exception as e:
                 st.error(f"Error initializing HuggingFaceEndpoint: {e}")
                 return
 
             prompt = create_prompt(name, persona_style)
-            # prompt and llm are both Runnables, chain them together
-            chain = RunnableSequence([prompt, llm])
+            chain = LLMChain(llm=llm, prompt=prompt)
 
             conversation = simulate_conversation(chain, turns=15, max_history_rounds=3)
             if conversation:
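Putting the changed pieces together, the updated wiring looks roughly like this (a minimal, self-contained sketch; the model id and the simplified prompt template are stand-ins for illustration, not part of the commit):

    import os
    from langchain.chains import LLMChain
    from langchain.prompts import ChatPromptTemplate
    from langchain_huggingface import HuggingFaceEndpoint

    # Hypothetical model id, for illustration only.
    endpoint_url = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"

    llm = HuggingFaceEndpoint(
        endpoint_url=endpoint_url,
        huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
        task="text-generation",
        temperature=0.7,
        max_new_tokens=512,
    )

    # Simplified stand-in for the app's create_prompt(name, persona_style).
    prompt = ChatPromptTemplate.from_template(
        "Chat history:\n{chat_history}\nHuman: {input}\nAI:"
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    print(chain.run(chat_history="", input="Hello!"))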
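Two caveats for anyone rebasing this later. The removed RunnableSequence([prompt, llm]) call was likely the original failure: RunnableSequence does not accept a single list of steps, and runnables are normally composed with the | operator instead. Also, newer LangChain releases deprecate LLMChain and Chain.run in favor of that composition syntax, so a forward-compatible alternative (a sketch of the alternative, not what this commit does) would be:

    # prompt | llm builds a RunnableSequence directly.
    chain = prompt | llm
    response = chain.invoke({"chat_history": "", "input": "Hello!"})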