# mental-health / app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Load the tokenizer and model directly
tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-3B")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-3B")
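
# Note (optional swap, not part of the original app): blenderbot-3B is a multi-gigabyte
# download and needs substantial RAM/VRAM. For quicker local testing, the smaller
# "facebook/blenderbot-400M-distill" checkpoint can be loaded instead:
# tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
# model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")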
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Compile the conversation into a message list for context
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:  # user message
            messages.append({"role": "user", "content": val[0]})
        if val[1]:  # assistant response
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    # BlenderBot is a plain seq2seq model, so build a single input string.
    # Note: only the user turns are used here; the system message and past
    # assistant replies are collected above but not fed to the model.
    input_text = "\n".join([msg["content"] for msg in messages if msg["role"] == "user"])

    # Tokenize the input and generate a response
    inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Decode the generated tokens into text
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Yield the full response once; gr.ChatInterface accepts generator functions,
    # so this could later be extended to stream partial text.
    yield response
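
# Hedged sketch (not part of the original app): true token-by-token streaming could
# use transformers' TextIteratorStreamer, running generate() in a background thread
# and yielding partial text as it arrives. The function name below is illustrative.
#
# from threading import Thread
# from transformers import TextIteratorStreamer
#
# def respond_streaming(message, history, system_message, max_tokens, temperature, top_p):
#     inputs = tokenizer(message, return_tensors="pt")
#     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
#     generate_kwargs = dict(
#         **inputs,
#         streamer=streamer,
#         max_new_tokens=max_tokens,
#         temperature=temperature,
#         top_p=top_p,
#         do_sample=True,
#     )
#     Thread(target=model.generate, kwargs=generate_kwargs).start()
#     partial = ""
#     for new_text in streamer:
#         partial += new_text
#         yield partial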
# Customize the system message for mental health support
default_system_message = """
You are a compassionate mental health specialist trained to listen empathetically and offer support.
When engaging with users, make sure to respond with kindness and provide general emotional support.
Avoid giving specific medical or clinical advice, but offer guidance, validate feelings, and suggest appropriate resources when needed.
Encourage open conversations and create a safe, non-judgmental space for the user to share.
"""
# Set up the Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=default_system_message, label="System Message (Mental Health Specialist)"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
if __name__ == "__main__":
    demo.launch()
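
# To run locally (assuming the dependencies are installed):
#   pip install gradio transformers torch
#   python app.py
# On Hugging Face Spaces, the app is launched automatically from this file.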