Leveraging Transformers and PyTorch for Multiple Choice Question Tasks

Community Article Published December 25, 2023


Introduction

Multiple Choice Questions (MCQs) are a common form of assessment in many domains, from education to recruitment. Transformer-based architectures have transformed natural language processing (NLP), and they suit MCQs well: a model can read the question together with each candidate answer and score how well the two fit. The Hugging Face Transformers models used here are PyTorch modules, so pre-trained transformers can be fine-tuned with ordinary PyTorch tooling. In this article, we fine-tune a pre-trained DistilBERT model on the SWAG multiple-choice dataset using Transformers, PyTorch and PyTorch Lightning.


Understanding Transformers and PyTorch

Transformers: These models excel in understanding contextual information in sequences through self-attention mechanisms. This ability to capture relationships between different parts of text is particularly beneficial in comprehending and answering MCQs effectively.

PyTorch: PyTorch's dynamic computation graph and user-friendly interface simplify the implementation and training of complex neural networks. Its flexibility allows seamless integration with transformer architectures, enabling streamlined development and experimentation.
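The self-attention mechanism mentioned above can be sketched in a few lines of NumPy. This is a minimal single-head version under simplifying assumptions: random matrices stand in for the learned query, key and value projections, and a real DistilBERT layer adds multiple heads, an output projection, residual connections and layer normalization.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row maximum before exponentiating for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x: np.ndarray, w_q: np.ndarray, w_k: np.ndarray, w_v: np.ndarray):
    """Single-head scaled dot-product self-attention.

    x: (seq_len, d_model) token representations.
    w_q, w_k, w_v: (d_model, d_head) projections (random here, learned in a real model).
    Returns (output, weights) with shapes (seq_len, d_head) and (seq_len, seq_len).
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    # Similarity of every token to every other token, scaled by sqrt(d_head)
    scores = q @ k.T / np.sqrt(k.shape[-1])
    # Each row becomes a probability distribution over the tokens being attended to
    weights = softmax(scores, axis=-1)
    # Each output row is a weighted mix of the value vectors
    return weights @ v, weights

rng = np.random.default_rng(0)
seq_len, d_model, d_head = 5, 8, 4
x = rng.normal(size=(seq_len, d_model))
w_q, w_k, w_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))
out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (5, 4) (5, 5)
```

Row `i` of `weights` says how much token `i` draws on every other token, which is how a context word can directly influence the representation of a candidate answer.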

Benefits of Utilizing Transformers with PyTorch

  1. Enhanced Contextual Understanding: Self-attention lets transformers model the relationships between the question context and each candidate answer, which is exactly what choosing the right option requires.

  2. Transfer Learning Capabilities: Pre-trained transformer models, such as BERT, RoBERTa, or ALBERT, can be fine-tuned on MCQ datasets using PyTorch. Starting from a pre-trained model significantly reduces training time and data requirements compared with training from scratch.

  3. Flexibility and Customization: PyTorch's flexibility allows for easy customization of transformer models. Researchers and developers can tailor the architectures, loss functions, and training methodologies to suit the specific requirements of MCQ tasks.

  4. State-of-the-Art Performance: Transformer-based models are the state of the art on a wide range of NLP benchmarks. When fine-tuned with PyTorch's optimization tools, they can reach high accuracy at picking the correct answer in MCQs.

  5. Scalability and Efficiency: Transformers process all tokens of a sequence in parallel (unlike recurrent networks), and PyTorch runs these computations efficiently on GPUs, so large batches of MCQs can be scored quickly. For latency-sensitive, real-time applications, a smaller distilled model such as DistilBERT (used below) is a common choice.
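The transfer-learning point above relies on a pre-trained encoder plus a small multiple-choice head. As a rough sketch of the mechanics, assuming the layout used by Hugging Face's `*ForMultipleChoice` models (each (context, choice) pair is encoded separately, a linear layer maps it to one score, and the scores are regrouped per question), here is a pure-Python version. The `overlap_score` function is a hypothetical stand-in for the encoder and head, not anything the real model does.

```python
from typing import Callable, List, Sequence, Tuple

def predict_choices(
    contexts: Sequence[str],
    choices: Sequence[Sequence[str]],
    score_pair: Callable[[str, str], float],
) -> Tuple[List[List[float]], List[int]]:
    num_choices = len(choices[0])
    # 1. Flatten: one (context, choice) pair per candidate, question by question
    pairs = [(ctx, ch) for ctx, chs in zip(contexts, choices) for ch in chs]
    # 2. Score each pair independently (in a real model: encoder + Linear(hidden_size, 1))
    flat_scores = [score_pair(ctx, ch) for ctx, ch in pairs]
    # 3. Regroup into one row of num_choices logits per question
    logits = [flat_scores[i:i + num_choices] for i in range(0, len(flat_scores), num_choices)]
    # 4. The prediction is the index of the highest-scoring choice in each row
    preds = [max(range(num_choices), key=row.__getitem__) for row in logits]
    return logits, preds

def overlap_score(context: str, choice: str) -> float:
    # Hypothetical stand-in for the model: count words shared by context and choice
    return float(len(set(context.lower().split()) & set(choice.lower().split())))

contexts = ["A drum line walks down the street", "Someone falls to the ground"]
choices = [
    ["eats lunch", "walks down the street", "sleeps", "reads quietly"],
    ["jumps up", "flies away", "swims", "lies on the ground"],
]
logits, preds = predict_choices(contexts, choices, overlap_score)
print(logits)  # [[0.0, 4.0, 0.0, 0.0], [0.0, 0.0, 0.0, 2.0]]
print(preds)   # [1, 3]
```

Because every choice is scored by the same function and only compared afterwards, the same pre-trained encoder works for any number of choices; fine-tuning teaches it to give the correct ending the highest score.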

Code Implementation

The pipeline has four steps. Each is described briefly below, followed by its code and output:

  1. Dataset Preparation: We use the SWAG dataset, loaded with the Hugging Face Datasets library. Each example pairs a context sentence (`sent1`) and the start of a second sentence (`sent2`) with four candidate endings (`ending0` to `ending3`); `label` is the index of the correct ending.
!pip install datasets transformers evaluate pytorch-lightning --quiet

import random

import transformers
from datasets import load_dataset  # load_metric was removed in datasets 3.0 and is not needed here
from transformers import AutoTokenizer

print(transformers.__version__)

# Defining a constant SEED for reproducibility in random operations
SEED = 42

# Seed Python's random module; transformers.set_seed also seeds NumPy and PyTorch
random.seed(SEED)
transformers.set_seed(SEED)

# Load the "regular" configuration of SWAG (note: this variable shadows the `datasets` module name)
datasets = load_dataset("swag", "regular")

datasets["train"][0]

Output

{'video-id': 'anetv_jkn6uvmqwh4',
 'fold-ind': '3416',
 'startphrase': 'Members of the procession walk down the street holding small horn brass instruments. A drum line',
 'sent1': 'Members of the procession walk down the street holding small horn brass instruments.',
 'sent2': 'A drum line',
 'gold-source': 'gold',
 'ending0': 'passes by walking down the street playing their instruments.',
 'ending1': 'has heard approaching them.',
 'ending2': "arrives and they're outside dancing and asleep.",
 'ending3': 'turns the lead singer watches the performance.',
 'label': 0}
def show_one(example):
    print(f"Context: {example['sent1']}")
    print(f"  A - {example['sent2']} {example['ending0']}")
    print(f"  B - {example['sent2']} {example['ending1']}")
    print(f"  C - {example['sent2']} {example['ending2']}")
    print(f"  D - {example['sent2']} {example['ending3']}")
    print(f"\nGround truth: option {['A', 'B', 'C', 'D'][example['label']]}")

show_one(datasets["train"][15])

Output

Context: Now it's someone's turn to rain blades on his opponent.
  A - Someone pats his shoulder and spins wildly.
  B - Someone lunges forward through the window.
  C - Someone falls to the ground.
  D - Someone rolls up his fast run from the water and tosses in the sky.

Ground truth: option C
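  2. Preprocessing: Each context (`sent1`) is repeated four times and paired with the four candidate second sentences (`sent2` followed by each ending). The tokenizer encodes every pair as one sequence, `[CLS] first [SEP] second [SEP]`, truncating it if needed, and returns the token ids and attention mask the model consumes. `preprocess_function` tokenizes all pairs in one call and then regroups the results in fours, so each example ends up with four encoded choices.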
model_checkpoint = 'distilbert-base-uncased' # "bert-base-uncased"
batch_size = 4
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
tokenizer("Hello, this one sentence!", "And this sentence goes with it.")

Output

{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
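The output above shows how a sentence pair is laid out: `[CLS]` (id 101), the first sentence, `[SEP]` (id 102), the second sentence and a closing `[SEP]`. The pure-Python sketch below rebuilds that layout from per-sentence token ids copied from the output; it is illustrative only, since the real tokenizer also performs WordPiece tokenization and truncation, and BERT (unlike DistilBERT) additionally returns `token_type_ids`.

```python
from typing import Dict, List

CLS_ID, SEP_ID = 101, 102  # [CLS] and [SEP] ids, as they appear in the output above

def build_pair(ids_a: List[int], ids_b: List[int]) -> Dict[str, List[int]]:
    # BERT-style pair layout: [CLS] first sentence [SEP] second sentence [SEP]
    input_ids = [CLS_ID] + ids_a + [SEP_ID] + ids_b + [SEP_ID]
    # No padding yet, so every position is a real token
    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Token ids of the two sentences, copied from the tokenizer output above
ids_a = [7592, 1010, 2023, 2028, 6251, 999]         # "Hello, this one sentence!"
ids_b = [1998, 2023, 6251, 3632, 2007, 2009, 1012]  # "And this sentence goes with it."
enc = build_pair(ids_a, ids_b)
print(enc["input_ids"])
# [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102]
```

This is the format each (context, candidate) pair takes in `preprocess_function`, which is why the decoded examples further down start with `[CLS]` and contain two `[SEP]` tokens.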
ending_names = ["ending0", "ending1", "ending2", "ending3"]

def preprocess_function(examples):
    # Repeat each first sentence four times to go with the four possibilities of second sentences.
    first_sentences = [[context] * 4 for context in examples["sent1"]]
    # Grab all second sentences possible for each context.
    question_headers = examples["sent2"]
    second_sentences = [[f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers)]
    
    # Flatten everything
    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])
    
    # Tokenize
    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)
    # Un-flatten
    return {k: [v[i:i+4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}

examples = datasets["train"][:5]
features = preprocess_function(examples)
print(len(features["input_ids"]), len(features["input_ids"][0]), [len(x) for x in features["input_ids"][0]])

Output:

5 4 [30, 25, 30, 28]
idx = 3
[tokenizer.decode(features["input_ids"][idx][i]) for i in range(4)]

Output:

['[CLS] a drum line passes by walking down the street playing their instruments. [SEP] members of the procession are playing ping pong and celebrating one left each in quick. [SEP]',
 '[CLS] a drum line passes by walking down the street playing their instruments. [SEP] members of the procession wait slowly towards the cadets. [SEP]',
 '[CLS] a drum line passes by walking down the street playing their instruments. [SEP] members of the procession makes a square call and ends by jumping down into snowy streets where fans begin to take their positions. [SEP]',
 '[CLS] a drum line passes by walking down the street playing their instruments. [SEP] members of the procession play and go back and forth hitting the drums while the audience claps for them. [SEP]']
encoded_datasets = datasets.map(preprocess_function, batched=True)
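`preprocess_function` relies on a flatten/un-flatten pattern: it tokenizes batch × 4 pairs in a single call and then regroups every 4 consecutive results. That regrouping is only correct because the pairs are built question by question. The pure-Python sketch below checks this invariant, with word counts standing in for the tokenizer; `flatten_pairs` and `unflatten` are hypothetical helper names, not part of any library.

```python
from typing import List, Sequence, Tuple

NUM_CHOICES = 4

def flatten_pairs(
    contexts: Sequence[str], headers: Sequence[str], endings: Sequence[Sequence[str]]
) -> Tuple[List[str], List[str]]:
    # Repeat each context once per choice and build "<header> <ending>" for each choice,
    # in question-major order so the choices of one question stay contiguous
    first = [ctx for ctx in contexts for _ in range(NUM_CHOICES)]
    second = [f"{h} {e}" for h, ends in zip(headers, endings) for e in ends]
    return first, second

def unflatten(values: Sequence, group: int = NUM_CHOICES) -> List[list]:
    # Regroup consecutive items so each inner list holds one question's choices
    return [list(values[i:i + group]) for i in range(0, len(values), group)]

contexts = ["Members of the procession walk down the street.", "Now it's someone's turn."]
headers = ["A drum line", "Someone"]
endings = [["passes by.", "has heard.", "arrives.", "turns."],
           ["pats his shoulder.", "lunges forward.", "falls to the ground.", "rolls up."]]
first, second = flatten_pairs(contexts, headers, endings)
# Stand-in for the tokenizer: "encode" each pair as its total word count
lengths = [len(a.split()) + len(b.split()) for a, b in zip(first, second)]
grouped = unflatten(lengths)
print(len(first), len(grouped), [len(g) for g in grouped])  # 8 2 [4, 4]
```

The nested comprehensions also avoid `sum(list_of_lists, [])`, which copies the growing list on every addition and is quadratic in the batch size; with the default `map` batch size of 1,000 examples the difference is small, but the comprehension is the more idiomatic choice.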
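  3. Fine-tuning setup: We load the pre-trained checkpoint with `AutoModelForMultipleChoice`, which adds a newly initialized head that gives each (context, ending) pair one score; the four scores are then compared with a softmax. We wrap the encoded splits in a PyTorch `Dataset` and build `DataLoader`s whose collate function pads every choice in a batch to the same length and reshapes the result to (batch_size, num_choices, sequence_length).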
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_roberta.py#L1266
model = AutoModelForMultipleChoice.from_pretrained(model_checkpoint)
# https://github.com/huggingface/datasets/issues/2165
from torch.utils.data import Dataset, DataLoader, RandomSampler
 
class HFDataset(Dataset):
    def __init__(self, dset):
        self.dset = dset

    def __getitem__(self, idx):
        x = self.dset[idx]
        return {'input_ids': x['input_ids'],
                'attention_mask': x['attention_mask'], # ignore token_type_ids
                'label' : x['label']}

    def __len__(self):
        return len(self.dset)

train_ds = HFDataset(encoded_datasets['train'])
test_ds = HFDataset(encoded_datasets['validation'])
len(encoded_datasets['train']), len(train_ds)

Output

(73546, 73546)
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch

@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        flattened_features = [[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features]
        flattened_features = sum(flattened_features, [])
        
        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        
        # Un-flatten
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        # Add back labels
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch
    
    
def HFDataLoader(dataset, tokenizer, batch_size=4, shuffle=True, num_workers=2):
    
    def listdict2dictlist(batch):
        '''
        Input: batch -- list of dict
        Output: dict of list-size-batch
        '''
        d = {}
        keys = batch[0].keys()
    
        for k in keys:
            d[k] = []
            for i in range(len(batch)):
                d[k].append(batch[i][k])
    
        return d
    
    def prepare_sample(sample):
        
        padding = True
        max_length = None
        pad_to_multiple_of = None

        features = listdict2dictlist(sample)
        batch_size = len(features["input_ids"])
        num_choices = len(features["input_ids"][0])
        
        flattened_features = {}
        for k,v in features.items():
            if k=='label': 
                continue
                
            flattened_features[k] = []
            for example in features[k]: # e.g. k='input_ids'
                for choice in example: # e.g. 4 choices per example
                    flattened_features[k].append(choice)

        
        batch = tokenizer.pad(
            flattened_features,
            padding=padding,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )
        
        # Un-flatten
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        # Add back labels
        batch["labels"] = torch.tensor(features['label'], dtype=torch.int64)
        return batch
    
    sampler = RandomSampler(dataset) if shuffle else None
    return DataLoader(dataset,
            sampler=sampler,
            batch_size=batch_size,
            collate_fn=prepare_sample,
            num_workers=num_workers)
import os
os.environ['TOKENIZERS_PARALLELISM'] = "false"

train_loader = HFDataLoader(train_ds, tokenizer,  batch_size=16)
test_loader = HFDataLoader(test_ds, tokenizer,  batch_size=16, shuffle=False)
for x in train_loader:
    print(x)
    break

Output:

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
{'input_ids': tensor([[[  101,  1037,  2450,  ...,     0,     0,     0],
         [  101,  1037,  2450,  ...,     0,     0,     0],
         [  101,  1037,  2450,  ...,     0,     0,     0],
         [  101,  1037,  2450,  ...,     0,     0,     0]],

        [[  101,  2111,  2024,  ...,     0,     0,     0],
         [  101,  2111,  2024,  ...,     0,     0,     0],
         [  101,  2111,  2024,  ...,     0,     0,     0],
         [  101,  2111,  2024,  ...,     0,     0,     0]],

        [[  101,  2059,  2007,  ...,     0,     0,     0],
         [  101,  2059,  2007,  ...,     0,     0,     0],
         [  101,  2059,  2007,  ...,     0,     0,     0],
         [  101,  2059,  2007,  ...,     0,     0,     0]],

        ...,

        [[  101,  2002, 17395,  ...,     0,     0,     0],
         [  101,  2002, 17395,  ...,     0,     0,     0],
         [  101,  2002, 17395,  ...,     0,     0,     0],
         [  101,  2002, 17395,  ...,     0,     0,     0]],

        [[  101,  1037,  2450,  ...,     0,     0,     0],
         [  101,  1037,  2450,  ...,     0,     0,     0],
         [  101,  1037,  2450,  ...,     0,     0,     0],
         [  101,  1037,  2450,  ...,     0,     0,     0]],

        [[  101,  2002, 12668,  ...,     0,     0,     0],
         [  101,  2002, 12668,  ...,     0,     0,     0],
         [  101,  2002, 12668,  ...,     0,     0,     0],
         [  101,  2002, 12668,  ...,     0,     0,     0]]]), 'attention_mask': tensor([[[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]],

        [[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]],

        [[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]],

        ...,

        [[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]],

        [[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]],

        [[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]]), 'labels': tensor([1, 0, 0, 2, 2, 0, 1, 0, 2, 3, 1, 3, 2, 0, 0, 2])}
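Both `DataCollatorForMultipleChoice` and `prepare_sample` follow the same recipe: flatten the batch to batch_size × num_choices sequences, let `tokenizer.pad` pad them to the longest one, and reshape to (batch_size, num_choices, max_len). The output above reflects this: every choice ends in 0s (the `[PAD]` id) with matching 0s in `attention_mask`. A pure-Python sketch of the recipe without the tokenizer, assuming right padding with pad id 0 and using arbitrary token ids:

```python
from typing import Dict, List

def pad_multiple_choice(
    batch_ids: List[List[List[int]]], pad_id: int = 0
) -> Dict[str, List[List[List[int]]]]:
    """Right-pad every choice of every example to the longest sequence in the batch.

    batch_ids is [batch][num_choices][ragged seq_len]; the result is
    [batch][num_choices][max_len] for both input_ids and attention_mask.
    """
    num_choices = len(batch_ids[0])
    # Flatten to batch * num_choices sequences, as prepare_sample does before tokenizer.pad
    flat = [seq for example in batch_ids for seq in example]
    # Dynamic padding: the target length is the longest sequence in this batch only
    max_len = max(len(seq) for seq in flat)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in flat]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in flat]

    def regroup(rows: List[List[int]]) -> List[List[List[int]]]:
        # Un-flatten back to [batch][num_choices][max_len]
        return [rows[i:i + num_choices] for i in range(0, len(rows), num_choices)]

    return {"input_ids": regroup(input_ids), "attention_mask": regroup(attention_mask)}

# Two questions with two choices each and ragged lengths (ids other than 101/102 are arbitrary)
batch = [[[101, 5, 102], [101, 6, 7, 102]],
         [[101, 8, 9, 10, 102], [101, 102]]]
padded = pad_multiple_choice(batch)
print(padded["input_ids"][0][0])       # [101, 5, 102, 0, 0]
print(padded["attention_mask"][1][1])  # [1, 1, 0, 0, 0]
```

Padding per batch rather than to a fixed maximum length keeps most batches short, which saves compute on SWAG, where most pairs are around 30 tokens.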
  4. Training: We wrap the model in a PyTorch Lightning `LightningModule`. Lightning runs the training loop (forward pass, backpropagation and optimizer steps) and handles device placement and mixed precision, so the module only has to define the loss, the validation step and the optimizer.
import torch
import pytorch_lightning as pl

class PLTransformer(pl.LightningModule):
    def __init__(
        self,
        model_base,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        super().__init__()

        # Save the scalar hyperparameters; ignoring the model avoids the freeze
        # caused by trying to store the whole model as a hyperparameter
        self.save_hyperparameters(ignore=["model_base"])
        self.model_base = model_base
        self.lr = learning_rate
        self.adam_epsilon = adam_epsilon
        self.weight_decay = weight_decay

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # The Hugging Face model computes the cross-entropy loss itself when labels are given
        return self.model_base(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

    def training_step(self, batch, batch_idx):
        outputs = self.forward(**batch)
        loss = outputs.loss
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self.forward(**batch)
        # logits have shape (batch_size, num_choices); predict the highest-scoring choice
        preds = torch.argmax(outputs.logits, dim=1)
        labels = batch["labels"]
        acc = (preds == labels).float().mean()
        # Lightning 2.x discards values returned from validation_step, so log the metrics
        self.log("val_loss", outputs.loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return {"loss": outputs.loss, "preds": preds, "labels": labels}

    def configure_optimizers(self):
        """Prepare AdamW with a constant learning rate (no warmup/decay schedule is configured)."""
        # Note: torch.optim.AdamW's own default weight_decay is 0.01; here the module's value is used
        optimizer = torch.optim.AdamW(
            self.model_base.parameters(), lr=self.lr, eps=self.adam_epsilon, weight_decay=self.weight_decay
        )
        return optimizer

pl_model = PLTransformer(model)

print(pl_model.to('cpu')(**x))
pl_model.to('cpu').training_step(x, 0)

Output:

MultipleChoiceModelOutput(loss=tensor(1.3915, grad_fn=<NllLossBackward0>), logits=tensor([[ 0.0144,  0.0352,  0.0017,  0.0198],
        [-0.0346, -0.0176, -0.0254, -0.0258],
        [-0.0054,  0.0022,  0.0579, -0.0057],
        [-0.0168, -0.0084, -0.0332,  0.0098],
        [ 0.0393,  0.0254,  0.0325,  0.0005],
        [ 0.0292,  0.0291,  0.0407,  0.0326],
        [-0.0220, -0.0277, -0.0461, -0.0345],
        [-0.0347, -0.0353, -0.0412, -0.0308],
        [ 0.0145,  0.0040, -0.0098, -0.0152],
        [ 0.0151, -0.0131,  0.0044, -0.0081],
        [-0.0025, -0.0051,  0.0014, -0.0056],
        [ 0.0293,  0.0211,  0.0291,  0.0254],
        [-0.0377,  0.0128, -0.0248, -0.0133],
        [ 0.0255,  0.0315,  0.0295,  0.0504],
        [-0.0230,  0.0035,  0.0003, -0.0109],
        [ 0.0458,  0.0464,  0.0418,  0.0733]], grad_fn=<ViewBackward0>), hidden_states=None, attentions=None)
tensor(1.3915, grad_fn=<NllLossBackward0>)
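The initial loss, 1.3915, is close to ln 4 ≈ 1.386. The newly initialized multiple-choice head produces near-zero logits, so the softmax over the four choices is almost uniform and the cross-entropy is roughly that of a random guess. A pure-Python sketch of the per-example cross-entropy that the model averages over the batch (the function name is ours, not a library API), applied to the first row of logits and its label from the outputs above:

```python
import math
from typing import Sequence

def choice_cross_entropy(logits_row: Sequence[float], label: int) -> float:
    # -log softmax(logits)[label], using the log-sum-exp trick for numerical stability
    m = max(logits_row)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits_row))
    return log_sum_exp - logits_row[label]

# First row of logits and first label from the outputs above
row, label = [0.0144, 0.0352, 0.0017, 0.0198], 1
print(choice_cross_entropy(row, label))
print(math.log(4))  # 1.3862943611198906, the loss of a uniform guess over 4 choices
```

If the loss does not fall clearly below ln 4 during training, the model is not learning to distinguish the endings.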
trainer = pl.Trainer(
    max_epochs=1,
    accelerator="gpu",
    devices=[0],
    precision="16-mixed",  # mixed-precision training; plain "16" is a legacy alias in Lightning 2.x
)

trainer.fit(pl_model, train_loader, test_loader)
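`PLTransformer` accepts a `warmup_steps` argument, but `configure_optimizers` never builds a schedule from it, so training runs at a constant learning rate. If linear warmup followed by linear decay is wanted, the learning-rate multiplier has the shape sketched below in pure Python; the step counts are illustrative, not taken from this SWAG run.

```python
def linear_warmup_decay(step: int, warmup_steps: int, total_steps: int) -> float:
    """Learning-rate multiplier: 0 -> 1 linearly during warmup, then 1 -> 0 linearly until total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

base_lr = 2e-5           # PLTransformer's default learning_rate
warmup, total = 10, 100  # illustrative step counts
for step in (0, 5, 10, 55, 100):
    print(step, base_lr * linear_warmup_decay(step, warmup, total))
```

In practice, `transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)` returns a `LambdaLR` with this shape. Lightning expects a per-step scheduler to be returned from `configure_optimizers` as `{"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}`, and in recent Lightning versions `self.trainer.estimated_stepping_batches` can supply the total number of steps.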

Conclusion

Combining transformer-based architectures with PyTorch gives a practical framework for MCQ tasks. Pre-trained transformers bring contextual understanding and transfer learning, while PyTorch (here together with PyTorch Lightning) provides the flexibility and optimization tools needed to fine-tune them, as shown by fine-tuning DistilBERT on SWAG with `AutoModelForMultipleChoice`.

As transformer architectures and PyTorch continue to evolve, this combination should become even more capable at automating MCQ assessment across diverse domains.

In summary, Transformers and PyTorch together provide a solid foundation for building effective MCQ models and, more broadly, automated question-answering systems.

Stay connected and support my work through various platforms:

Huggingface: For natural language processing and AI-related projects, you can explore my Huggingface profile at https://huggingface.co/Andyrasika.

LinkedIn: To stay updated on my latest projects and posts, you can follow me on LinkedIn. Here is the link to my profile: https://www.linkedin.com/in/ankushsingal/.

Requests and questions: If you have a project in mind that you’d like me to work on or if you have any questions about the concepts I’ve explained, don’t hesitate to let me know. I’m always looking for new ideas for future Notebooks and I love helping to resolve any doubts you might have.

Resources: