|
|
|
import random
import string

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import MaxNLocator
from transformers import pipeline
|
|
|
# Fill-mask models pre-loaded at start-up; any other UI choice means
# "load the user-supplied Hugging Face model".
MODEL_NAMES = ["bert-base-uncased", "roberta-base", "bert-large-uncased", "roberta-large"]
# Radio-button label that selects the user-supplied model.
OWN_MODEL_NAME = 'add-a-model'

# Rounding for reported percentages, and a tiny constant added to
# denominators so divisions never hit zero.
DECIMAL_PLACES = 1
EPS = 1e-5

# Date sweep: NUM_PTS evenly spaced years from START_YEAR to STOP_YEAR
# (inclusive), truncated to whole years and stored as strings.
DATE_SPLIT_KEY = "DATE"
START_YEAR = 1801
STOP_YEAR = 1999
NUM_PTS = 20
DATES = [str(year) for year in np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()]
|
|
|
|
|
|
|
|
|
PLACE_SPLIT_KEY = "PLACE"

# Country sweep. Per the UI text, this is the bottom 10 followed by the top 10
# countries of the 2021 Global Gender Gap ranking; order is the x-axis order
# (presumably lowest-ranked first, highest-ranked last).
PLACES = [
    # Bottom 10.
    "Afghanistan", "Yemen", "Iraq", "Pakistan", "Syria",
    "Democratic Republic of Congo", "Iran", "Mali", "Chad", "Saudi Arabia",
    # Top 10.
    "Switzerland", "Ireland", "Lithuania", "Rwanda", "Namibia",
    "Sweden", "New Zealand", "Norway", "Finland", "Iceland",
]
|
|
|
|
|
|
|
|
|
|
|
# Subreddit sweep. Per the UI text, sorted by increasing self-identified
# female participation (bburky's subreddit gender ratios), so the order is the
# x-axis order and must be preserved.
SUBREDDITS = [
    "GlobalOffensive", "pcmasterrace", "nfl", "sports",
    "The_Donald", "leagueoflegends", "Overwatch", "gonewild",
    "Futurology", "space", "technology", "gaming",
    "Jokes", "dataisbeautiful", "woahdude", "askscience",
    "wow", "anime", "BlackPeopleTwitter", "politics",
    "pokemon", "worldnews", "reddit.com", "interestingasfuck",
    "videos", "nottheonion", "television", "science",
    "atheism", "movies", "gifs", "Music",
    "trees", "EarthPorn", "GetMotivated", "pokemongo",
    "news", "Fitness", "Showerthoughts", "OldSchoolCool",
    "explainlikeimfive", "todayilearned", "gameofthrones", "AdviceAnimals",
    "DIY", "WTF", "IAmA", "cringepics",
    "tifu", "mildlyinteresting", "funny", "pics",
    "LifeProTips", "creepy", "personalfinance", "food",
    "AskReddit", "books", "aww", "sex",
    "relationships",
]
|
|
|
# [male, female] word pairs. Words from either column found in the input text
# are masked; the model's probability mass on each column is reported
# separately.
GENDERED_LIST = [
    ["he", "she"], ["him", "her"], ["his", "hers"], ["himself", "herself"],
    ["male", "female"], ["man", "woman"], ["men", "women"],
    ["husband", "wife"], ["father", "mother"], ["boyfriend", "girlfriend"],
    ["brother", "sister"], ["actor", "actress"],
]
|
|
|
|
|
|
|
# Build one fill-mask pipeline per pre-loaded model at import time, keyed by
# model name, so requests reuse them instead of reloading weights.
models = {name: pipeline("fill-mask", model=name) for name in MODEL_NAMES}
|
|
|
|
|
|
|
|
|
def get_gendered_token_ids(gendered_list=None):
    """Split [male, female] word pairs into two parallel word lists.

    Despite the name, this returns the words themselves, not tokenizer ids.

    Args:
        gendered_list: Sequence of ``[male_word, female_word]`` pairs.
            Defaults to the module-level ``GENDERED_LIST``.

    Returns:
        ``(male_gendered_tokens, female_gendered_tokens)`` as two lists, in
        the same order as the pairs.
    """
    pairs = GENDERED_LIST if gendered_list is None else gendered_list
    # Index rather than unpack so pairs carrying extra entries still work.
    male_gendered_tokens = [pair[0] for pair in pairs]
    female_gendered_tokens = [pair[1] for pair in pairs]
    return male_gendered_tokens, female_gendered_tokens
|
|
|
|
|
def prepare_text_for_masking(input_text, mask_token, gendered_tokens, split_key):
    """Mask gendered words in ``input_text`` and split it on ``split_key``.

    Words are compared case-insensitively with ``gendered_tokens`` after
    stripping leading/trailing punctuation, so "Her," or "him." are masked
    too. The punctuation is kept around the mask ("him." -> "[MASK]."). Runs
    of whitespace in the input are collapsed to single spaces.

    Args:
        input_text: Text containing gendered words and the place-holder.
        mask_token: The model's mask token, e.g. ``"[MASK]"``.
        gendered_tokens: Words to mask (compared case-insensitively).
        split_key: Place-holder to split on. If empty, the text is returned
            as a single portion instead of raising.

    Returns:
        ``(text_portions, num_masks)``. ``text_portions`` is the masked text
        split on ``split_key``; joining it with a value injects that value. 
        ``num_masks`` is the number of mask tokens in the masked text.
    """
    targets = {token.lower() for token in gendered_tokens}
    masked_words = []
    for word in input_text.split():
        core = word.strip(string.punctuation)
        if core and core.lower() in targets:
            start = len(word) - len(word.lstrip(string.punctuation))
            word = word[:start] + mask_token + word[start + len(core):]
        masked_words.append(word)
    masked_text = ' '.join(masked_words)

    # Count every mask the model will see, including any typed by the user.
    num_masks = masked_text.count(mask_token)

    if not split_key:
        # str.split('') raises ValueError; with no place-holder there is
        # nothing to inject, so return the whole text as one portion.
        return [masked_text], num_masks
    return masked_text.split(split_key), num_masks
|
|
|
|
|
def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token, num_preds):
    """Average, over masks, the probability mass on the ``gendered_token`` words.

    Args:
        mask_filled_text: One list of candidate predictions per mask. Each
            candidate is a dict with at least ``"token_str"`` and ``"score"``.
        gendered_token: Words whose candidates count toward the mass. Each
            candidate string is stripped and lower-cased before comparison.
        num_preds: Number of masks to average over.

    Returns:
        The average mass as a percentage, rounded to ``DECIMAL_PLACES``.
    """
    per_mask_mass = []
    for candidates in mask_filled_text:
        mass = 0
        for candidate in candidates:
            if candidate["token_str"].strip().lower() in gendered_token:
                mass += candidate["score"]
        per_mask_mass.append(mass)

    # EPS keeps the division defined when num_preds is 0.
    average = sum(per_mask_mass) / (EPS + num_preds)
    return round(average * 100, DECIMAL_PLACES)
|
|
|
|
|
|
|
|
|
def get_figure(df, gender, n_fit=1, model_name=None):
    """Plot one pronoun-probability column against the injected values.

    The measured values are drawn as red dots. Two more layers go under them:
    a least-squares polynomial trend line, and a +/- one-sigma band from the
    fit's coefficient covariance.

    Args:
        df: DataFrame with an ``'x-axis'`` column of injected values and one
            probability column. The first column after ``'x-axis'`` is plotted.
        gender: Word used in the plot title, e.g. ``'female'``.
        n_fit: Requested polynomial degree. ``None`` falls back to 1. The
            degree is capped at ``len(df) - 2``, because ``np.polyfit(cov=True)``
            needs more points than coefficients. With fewer than two points no
            trend is drawn.
        model_name: Model name shown in the plot title.

    Returns:
        The matplotlib ``Figure``. It is already removed from pyplot's figure
        registry, so repeated calls do not accumulate open figures.
    """
    df = df.set_index('x-axis')
    value_col = df.columns[0]
    ys = df[value_col].to_numpy(dtype=float)
    xs = np.arange(len(ys))

    fig, ax = plt.subplots()
    fig.set_figheight(3)
    fig.set_figwidth(9)

    degree = min(1 if n_fit is None else int(n_fit), len(xs) - 2)
    if degree >= 0:
        p, C_p = np.polyfit(xs, ys, degree, cov=True)
        # Evaluate the fit and its 1-sigma band slightly beyond the data range.
        t = np.linspace(xs.min() - 1, xs.max() + 1, 10 * len(xs))
        TT = np.vstack([t ** (degree - i) for i in range(degree + 1)]).T
        yi = np.dot(TT, p)
        C_yi = np.dot(TT, np.dot(C_p, TT.T))
        sig_yi = np.sqrt(np.diag(C_yi))
        ax.fill_between(t, yi + sig_yi, yi - sig_yi, alpha=.25)
        ax.plot(t, yi, '-')

    # Label the measured points explicitly. ax.legend(labels) assigns labels
    # to artists in creation order, which put this label on the fit/band
    # instead of the data.
    ax.plot(df.index, ys, 'ro', label=str(value_col))
    ax.legend()

    ax.axis('tight')
    ax.set_xlabel("Value injected into input text")
    ax.set_title(
        f"Probability of predicting {gender} pronouns on {model_name}.")
    ax.set_ylabel("Softmax prob for pronouns")
    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.tick_params(axis='x', labelrotation=5)

    # pyplot keeps every figure alive until closed. Detach it so each request
    # does not leak two figures; the Figure object itself can still be saved
    # or rendered.
    plt.close(fig)
    return fig
|
|
|
|
|
|
|
def predict_gender_pronouns(
    model_name,
    own_model_name,
    indie_vars,
    split_key,
    normalizing,
    n_fit,
    input_text,
):
    """Run inference on input_text for each model type, returning df and plots of percentage
    of gender pronouns predicted as female and male in each target text.

    Gendered words in ``input_text`` are replaced by the model's mask token.
    Then, for each comma separated value in ``indie_vars``, every occurrence of
    ``split_key`` is replaced by that value and the text is run through the
    fill-mask model. The model's probability mass on female and male words is
    averaged over the masks.

    Args:
        model_name: One of ``MODEL_NAMES``. Any other value, including no
            selection, means "load ``own_model_name``".
        own_model_name: Hugging Face model id used when ``model_name`` is not
            a pre-loaded model.
        indie_vars: Comma separated values to inject. Surrounding whitespace
            is stripped and empty entries are ignored.
        split_key: Place-holder token in ``input_text`` to replace.
        normalizing: Truthy to rescale female/male values so they sum to ~100.
        n_fit: Degree of the polynomial trend drawn on each plot.
        input_text: Text containing gendered words and the place-holder.

    Returns:
        ``(display_text, female_fig, male_fig, results_df)``.

    Raises:
        ValueError: if a custom model is required but no name was given, if
            no values were supplied, or if the text has no gendered word to mask.
    """
    if model_name in MODEL_NAMES:
        model = models[model_name]
    else:
        own_model_name = (own_model_name or "").strip()
        if not own_model_name:
            raise ValueError(
                f"Select a pre-loaded model, or choose '{OWN_MODEL_NAME}' and "
                "enter a Hugging Face fill-mask model name.")
        model = pipeline("fill-mask", model=own_model_name)
        model_name = OWN_MODEL_NAME

    mask_token = model.tokenizer.mask_token

    # Values arrive as ", "-joined text. Without stripping, each value would
    # carry a leading space, which doubles the space in the model input and
    # skews the x-axis labels.
    indie_vars_list = [value.strip() for value in (indie_vars or "").split(',')]
    indie_vars_list = [value for value in indie_vars_list if value]
    if not indie_vars_list:
        raise ValueError("Enter at least one comma separated value to inject.")

    male_gendered_tokens, female_gendered_tokens = get_gendered_token_ids()
    text_segments, num_preds = prepare_text_for_masking(
        input_text, mask_token, male_gendered_tokens + female_gendered_tokens, split_key)
    if num_preds == 0:
        # Nothing to score without at least one masked gendered word.
        raise ValueError(
            "Input text must contain at least one gendered word (e.g. 'she', 'him') to mask.")

    female_pronoun_preds = []
    male_pronoun_preds = []
    for indie_var in indie_vars_list:
        target_text = indie_var.join(text_segments)
        mask_filled_text = model(target_text)

        # Normalize to one list of candidates per mask. A flat output (first
        # element not a list) is treated as the candidates of a single mask.
        if not isinstance(mask_filled_text[0], list):
            mask_filled_text = [mask_filled_text]

        female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text, female_gendered_tokens, num_preds))
        male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text, male_gendered_tokens, num_preds))

    if normalizing:
        # Rescale so female + male is ~100 per injected value. EPS guards
        # against division by zero when neither is predicted.
        female = np.asarray(female_pronoun_preds)
        male = np.asarray(male_pronoun_preds)
        total_gendered_probs = female + male + EPS
        female_pronoun_preds = np.around(
            female / total_gendered_probs * 100, decimals=DECIMAL_PLACES)
        male_pronoun_preds = np.around(
            male / total_gendered_probs * 100, decimals=DECIMAL_PLACES)

    results_df = pd.DataFrame({'x-axis': indie_vars_list})
    results_df['female_pronouns'] = female_pronoun_preds
    results_df['male_pronouns'] = male_pronoun_preds

    female_fig = get_figure(results_df.drop(
        'male_pronouns', axis=1), 'female', n_fit, model_name)
    male_fig = get_figure(results_df.drop(
        'female_pronouns', axis=1), 'male', n_fit, model_name)

    # Show one randomly chosen model input as a sample of what was scored.
    display_text = random.choice(indie_vars_list).join(text_segments)

    return (
        display_text,
        female_fig,
        male_fig,
        results_df,
    )
|
|
|
|
|
|
|
# Page title and intro text.
# NOTE(review): neither name appears to be used by the Blocks UI; kept for
# compatibility.
title = "Causing Gender Pronouns"
description = "\n## Intro\n"
|
|
|
|
|
def _example_inputs(model, own_model, values, place_holder, text):
    """Build one example's field values.

    The order is: model, custom model name, comma separated values,
    place-holder, normalize flag, fit degree, input text. Every example uses
    normalize "False" and fit degree 1.
    """
    return [model, own_model, ', '.join(values), place_holder, "False", 1, text]


date_example = _example_inputs(
    MODEL_NAMES[1], '', DATES, 'DATE', 'She was a teenager in DATE.')

place_example = _example_inputs(
    MODEL_NAMES[0], '', PLACES, 'PLACE', 'She became an adult in PLACE.')

subreddit_example = _example_inputs(
    MODEL_NAMES[3], '', SUBREDDITS, 'SUBREDDIT', 'She was a kid. SUBREDDIT.')

own_model_example = _example_inputs(
    OWN_MODEL_NAME, 'emilyalsentzer/Bio_ClinicalBERT', DATES, 'DATE',
    'She was exposed to the virus in DATE.')
|
|
|
|
|
def date_fn():
    """Click callback: return the pre-filled field values for the date-sweep example."""
    return date_example


def place_fn():
    """Click callback: return the pre-filled field values for the country-sweep example."""
    return place_example


def reddit_fn():
    """Click callback: return the pre-filled field values for the subreddit-sweep example."""
    return subreddit_example


def your_fn():
    """Click callback: return the pre-filled field values for the add-a-model example."""
    return own_model_example
|
|
|
|
|
|
|
# Gradio Blocks UI: example buttons, input fields, and output plots/table.
demo = gr.Blocks()

with demo:
    gr.Markdown("# Spurious Correlation Evaluation for Pre-trained LLMs")
    gr.Markdown("Find spurious correlations between seemingly independent variables (for example between `gender` and `time`) in almost any BERT-like LLM on Hugging Face, below.")

    gr.Markdown("## Instructions for this Demo")
    gr.Markdown("1) Click on one of the examples below (where we sweep through a spectrum of `places`, `dates` and `subreddits`) to pre-populate the input fields.")
    gr.Markdown("2) Check out the pre-populated fields as you scroll down to the ['Hit Submit...'] button!")
    gr.Markdown("3) Repeat steps (1) and (2) with more pre-populated inputs or with your own values in the input fields!")

    gr.Markdown("## Example inputs")
    gr.Markdown("Click a button below to pre-populate input fields with example values. Then scroll down to Hit Submit to generate predictions.")
    with gr.Row():
        date_gen = gr.Button('Click for date example inputs')
        gr.Markdown("<-- x-axis sorted from older to more recent dates:")

        place_gen = gr.Button('Click for country example inputs')
        gr.Markdown(
            "<-- x-axis sorted by bottom 10 and top 10 [Global Gender Gap](https://www3.weforum.org/docs/WEF_GGGR_2021.pdf) ranked countries:")

        subreddit_gen = gr.Button('Click for Subreddit example inputs')
        gr.Markdown(
            "<-- x-axis sorted in order of increasing self-identified female participation (see [bburky](http://bburky.com/subredditgenderratios/)): ")

        your_gen = gr.Button('Add-a-model example inputs')
        gr.Markdown("<-- x-axis dates, with your own model loaded! (If first time, try another example, it can take a while to load new model.)")

    gr.Markdown("## Input fields")
    gr.Markdown(
        "A) Pick a spectrum of comma separated values for text injection and x-axis.")

    with gr.Row():
        x_axis = gr.Textbox(
            lines=3,
            label="A) Comma separated values for text injection and x-axis",
        )

    gr.Markdown("B) Pick a pre-loaded BERT-family model of interest on the right.")
    gr.Markdown(f"Or C) select `{OWN_MODEL_NAME}`, then add the name of any other Hugging Face model that supports the [fill-mask](https://huggingface.co/models?pipeline_tag=fill-mask) task on the right (note: this may take some time to load).")

    with gr.Row():
        model_name = gr.Radio(
            MODEL_NAMES + [OWN_MODEL_NAME],
            type="value",
            label="B) BERT-like model.",
        )
        own_model_name = gr.Textbox(
            label="C) If you selected an 'add-a-model' model, put any Hugging Face pipeline model name (that supports the fill-mask task) here.",
        )

    gr.Markdown("D) Pick whether you want the predictions normalized to these gendered terms only.")
    gr.Markdown("E) Also tell the demo what special token you will use in your input text, that you would like replaced with the spectrum of values you listed above.")
    gr.Markdown("And F) the degree of polynomial fit used for highlighting potential spurious association.")

    with gr.Row():
        to_normalize = gr.Dropdown(
            ["False", "True"],
            label="D) Normalize model's predictions to only the gendered ones?",
            type="index",
        )
        place_holder = gr.Textbox(
            label="E) Special token place-holder",
        )
        n_fit = gr.Dropdown(
            list(range(1, 5)),
            label="F) Degree of polynomial fit",
            type="value",
        )

    gr.Markdown(
        "G) Finally, add input text that includes at least one gendered pronoun and one place-holder token specified above.")

    with gr.Row():
        input_text = gr.Textbox(
            lines=2,
            label="G) Input text with pronouns and place-holder token",
        )

    gr.Markdown("## Outputs!")

    with gr.Row():
        btn = gr.Button("Hit submit to generate predictions!")

    with gr.Row():
        sample_text = gr.Textbox(
            type="auto", label="Output text: Sample of text fed to model")
    with gr.Row():
        female_fig = gr.Plot(type="auto")
        male_fig = gr.Plot(type="auto")
    with gr.Row():
        df = gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability for pronouns predictions",
        )

    # The example callbacks fill these fields in the same order that
    # predict_gender_pronouns reads them.
    input_fields = [model_name, own_model_name, x_axis, place_holder,
                    to_normalize, n_fit, input_text]

    # Event wiring (no layout container needed; the previous wrapping
    # gr.Row only added an empty row to the page).
    date_gen.click(date_fn, inputs=[], outputs=input_fields)
    place_gen.click(place_fn, inputs=[], outputs=input_fields)
    subreddit_gen.click(reddit_fn, inputs=[], outputs=input_fields)
    your_gen.click(your_fn, inputs=[], outputs=input_fields)

    btn.click(
        predict_gender_pronouns,
        inputs=input_fields,
        outputs=[sample_text, female_fig, male_fig, df])


if __name__ == "__main__":
    demo.launch(debug=True)
|
|