RedBERT - a Reddit post classifier

This model, based on DistilBERT, is fine-tuned to predict the subreddit of a Reddit post.

Usage

Preparations

The model uses the transformers library, so make sure to install it.

pip install "transformers[torch]"

After installation, the model can be loaded from Hugging Face.
The model is stored locally, so if you run these lines again it is loaded from the cache.

from transformers import pipeline
pipe = pipeline("text-classification", model="traberph/RedBERT")

Basic

For a simple classification task, call the pipeline with the text of your choice:

text = "I (33f) need to explain to my coworker (30m) I don't want his company on the commute back home"
pipe(text)

output:
[{'label': 'relationships', 'score': 0.9622366428375244}]
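The pipeline also accepts a list of strings and then returns one {'label', 'score'} dict per input, in input order. Below is a minimal sketch that classifies several posts in one call and flags low-confidence predictions. The helper name label_posts and the 0.5 threshold are illustrative assumptions, not part of this model card. The classifier is passed in as a callable so the helper can be tried with a stub; in practice you would pass the pipe object created above.

```python
from typing import Callable, Dict, List, Optional, Tuple

# One prediction as returned by the text-classification pipeline,
# e.g. {'label': 'relationships', 'score': 0.96}.
Prediction = Dict[str, object]


def label_posts(
    classify: Callable[[List[str]], List[Prediction]],
    posts: List[str],
    min_score: float = 0.5,  # hypothetical confidence threshold
) -> List[Tuple[str, Optional[str], float]]:
    """Classify a batch of posts in a single call (hypothetical helper).

    Returns (post, label, score) for every post; label is None when the
    top score is below `min_score`.
    """
    predictions = classify(posts)  # one dict per post, same order as `posts`
    results = []
    for post, pred in zip(posts, predictions):
        score = float(pred["score"])
        label = pred["label"] if score >= min_score else None
        results.append((post, label, score))
    return results
```

Usage with the real model: label_posts(pipe, [text, "Which GPU should I buy for gaming?"]).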

Multiclass with visualization

Everyone likes visualizations! The following example outputs the 5 most probable labels and visualizes the result.
It needs two additional packages, pandas and seaborn:

pip install pandas seaborn

import pandas as pd
import seaborn as sns

# if the model is already loaded this can be skipped
from transformers import pipeline
pipe = pipeline("text-classification", model="traberph/RedBERT")

text = "Today I spilled coffee over my pc. It started to smoke and the screen turned black. I guess I have a problem now."

# predict the 5 most probable labels
res = pipe(text, top_k=5)

# create a pandas dataframe from the result
df = pd.DataFrame(res)

# use seaborn to create a barplot
sns.barplot(data=df, x='score', y='label', color='steelblue')

output: [figure: bar plot of the 5 most probable labels and their scores]
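If pandas and seaborn are not available, the same top-k result can be rendered as a plain-text bar chart. This is a minimal sketch, not part of the original example; the helper name text_barplot and the bar width are assumptions. It expects the list of {'label', 'score'} dicts returned by pipe(text, top_k=5).

```python
from typing import Dict, List


def text_barplot(results: List[Dict[str, object]], width: int = 40) -> str:
    """Render top-k pipeline results as a plain-text bar chart (hypothetical helper).

    Each bar's length is proportional to the label's score (0..1).
    """
    if not results:
        return ""
    # Most probable label first, matching the order of the seaborn plot.
    rows = sorted(results, key=lambda r: float(r["score"]), reverse=True)
    name_width = max(len(str(r["label"])) for r in rows)
    lines = []
    for r in rows:
        score = float(r["score"])
        bar = "#" * round(score * width)
        lines.append(f"{str(r['label']):<{name_width}} {bar} {score:.3f}")
    return "\n".join(lines)
```

Usage: print(text_barplot(pipe(text, top_k=5))).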

Training

Training the final version of this model took 130 hours on a single Tesla P100 GPU.
90% of the webis/tldr-17 dataset was used for this version.
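As a rough worked example, using the dataset size given under Bias and Limitations (3,848,330 posts), a 90% share corresponds to roughly 3.46 million training posts. How the split was drawn and what the remaining 10% was used for are not stated here, so this sketch only illustrates the arithmetic.

```python
TOTAL_POSTS = 3_848_330  # size of webis/tldr-17 as stated in this model card

# Integer arithmetic avoids floating-point rounding (0.9 * n can land just below n * 9/10).
train_posts = TOTAL_POSTS * 90 // 100
remaining_posts = TOTAL_POSTS - train_posts

print(f"training posts: {train_posts:,}")       # training posts: 3,463,497
print(f"remaining posts: {remaining_posts:,}")  # remaining posts: 384,833
```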

Bias and Limitations

The webis/tldr-17 dataset used to train this model contains 3,848,330 posts from 29,651 subreddits.
However, these posts are not evenly distributed across subreddits: 589,947 posts belong to r/AskReddit, about 15% of the whole dataset, while other subreddits are underrepresented.
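To put the imbalance in numbers, the figures above can be compared with a uniform distribution. This sketch only uses the counts stated in this section; the comparison against the mean is an illustration, not an analysis from the original training setup.

```python
TOTAL_POSTS = 3_848_330
NUM_SUBREDDITS = 29_651
ASKREDDIT_POSTS = 589_947

askreddit_share = ASKREDDIT_POSTS / TOTAL_POSTS  # fraction of all posts
mean_posts = TOTAL_POSTS / NUM_SUBREDDITS        # posts per subreddit if evenly spread
ratio = ASKREDDIT_POSTS / mean_posts             # AskReddit relative to that mean

print(f"AskReddit share: {askreddit_share:.1%}")      # AskReddit share: 15.3%
print(f"mean posts per subreddit: {mean_posts:.1f}")  # mean posts per subreddit: 129.8
print(f"AskReddit vs. mean: {ratio:.0f}x")            # AskReddit vs. mean: 4545x
```

In other words, r/AskReddit alone has several thousand times as many posts as an average subreddit in the dataset.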

[figure: top subreddits distribution]

This imbalance in the subreddit distribution is reflected in the model and can be observed during inference.

[figure: class labels for "Biden says US is at tipping point on gun control: We will ban assault weapons in this country", a post from r/politics]
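One rough way to observe this bias yourself is to classify a set of posts whose subreddit you know and count how often each label shows up among the top-k predictions. The helper label_frequency below is hypothetical; it assumes that calling the pipeline on a list of posts with top_k set returns one list of {'label', 'score'} dicts per post.

```python
from collections import Counter
from typing import Dict, Iterable, List


def label_frequency(batch_results: Iterable[List[Dict[str, object]]]) -> Counter:
    """Count how often each label appears among the top-k predictions (hypothetical helper).

    `batch_results` holds one top-k list per post, e.g. the output of
    pipe(posts, top_k=5) for a list of posts.
    """
    counts = Counter()
    for top_k in batch_results:
        counts.update(str(pred["label"]) for pred in top_k)
    return counts
```

Usage: label_frequency(pipe(posts, top_k=5)).most_common(5). A large, generic class such as AskReddit appearing for many posts from an unrelated subreddit would be consistent with the imbalance described above.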