Problem with highly padded sequences

#49
by fmrs - opened

Hi all,

I am trying to fine-tune the ModernBERT model for a text classification task, but I get NaN logits when the input is heavily padded. Example (inputs already tokenized):

import torch
from transformers import ModernBertForSequenceClassification

# Load the model (the inputs below are already tokenized, so no tokenizer is needed)
model_name = "answerdotai/ModernBERT-base"

model = ModernBertForSequenceClassification.from_pretrained(model_name, num_labels=10, reference_compile=False)

# Provided input_ids tensor with all sequences
input_ids = torch.tensor([
    [50281, 28782, 2380, 281, 776, 10309, 3237, 21726, 50282, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283],
    [50281, 53, 2555, 369, 4895, 275, 1072, 1617, 352, 369,
     2197, 275, 285, 1335, 2509, 253, 1072, 2181, 352, 369,
     2197, 275, 323, 21726, 50282, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283],
    [50281, 42, 1904, 626, 4763, 667, 5511, 326, 352, 369,
     1146, 4895, 1919, 846, 352, 2692, 598, 3066, 46690, 387,
     776, 3906, 15, 21726, 50282, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283],
    [50281, 47638, 2, 388, 3271, 12278, 14, 5683, 50282, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
     50283, 50283, 50283, 50283, 50283]
], device="cuda")

# Build the attention mask: 1 for real tokens, 0 for padding (pad token ID 50283)
attention_mask = (input_ids != 50283).long()

# Define labels
labels = torch.tensor([0, 1, 2, 3], device="cuda")

# Move the model to the GPU (the input tensors above are created directly on CUDA)
device = torch.device("cuda")
model = model.to(device)

# Prepare model inputs
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "labels": labels
}

# Perform a forward pass
outputs = model(**inputs)

# Output the logits
print("Logits:", outputs.logits)
print("Logits shape:", outputs.logits.shape)

# Check for NaN values
nan_indices = torch.isnan(outputs.logits).any(dim=1).nonzero(as_tuple=True)[0]
if nan_indices.numel() > 0:
    print(f"NaN detected in logits for batch indices: {nan_indices.tolist()}")

# Print the loss if labels are provided
if outputs.loss is not None:
    print("Loss:", outputs.loss) 

This produces the following output:

Logits: tensor([[    nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan],
        [-0.8703, -0.6405, -0.6561, -0.2938, -1.2773, -0.1523,  0.9613,  1.4057,
         -0.5483,  0.0425],
        [-0.5069, -0.5104, -0.6058, -0.6675, -0.5392,  0.3078,  0.5109,  0.7993,
         -0.0116, -0.1695],
        [    nan,     nan,     nan,     nan,     nan,     nan,     nan,     nan,
             nan,     nan]], device='cuda:0', grad_fn=<AddmmBackward0>)
Logits shape: torch.Size([4, 10])
NaN detected in logits for batch indices: [0, 3]
Loss: tensor(nan, device='cuda:0', grad_fn=<NllLossBackward0>)

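As you can see, inputs 0 and 3 produce NaN logits. They are also the most heavily padded rows (9 real tokens out of 75, the rest being padding ID 50283), so the padding seems to be the cause. Any guesses why? Thanks a lot for helping!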
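To make "highly padded" concrete, the sketch below counts real tokens and the padding share per row using the same mask construction as the snippet above. It is only an illustration: the per-row lengths (9, 25, 25, 9) and the total length of 75 are read off the example batch, while the non-padding token ID is a placeholder.

```python
import torch

PAD_ID = 50283  # padding token ID used in the example above
SEQ_LEN = 75    # every row in the example is padded to 75 tokens

# Stand-in for the example batch: only the number of non-padding tokens
# per row is taken from the post; the token ID 1 is a placeholder.
real_lengths = [9, 25, 25, 9]
input_ids = torch.full((len(real_lengths), SEQ_LEN), PAD_ID)
for row, n in enumerate(real_lengths):
    input_ids[row, :n] = 1

# Same mask construction as in the original snippet
attention_mask = (input_ids != PAD_ID).long()

real_tokens = attention_mask.sum(dim=1)               # non-padding tokens per row
pad_fraction = 1.0 - real_tokens.float() / SEQ_LEN    # share of padding per row

for row, (n, frac) in enumerate(zip(real_tokens.tolist(), pad_fraction.tolist())):
    print(f"row {row}: {n} real tokens, {frac:.0%} padding")
```

The two rows that returned NaN logits (0 and 3) are exactly the ones with about 88% padding; the rows with about 67% padding were fine.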

I'm facing the same issue (transformers 4.48.0), and I'm not using flash-attn.

Could you give more details about your setups?
I can't reproduce it on the latest version of transformers with SDPA, installed via:
pip install git+https://github.com/huggingface/transformers.git
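To answer the request for setup details, a small script like the following can gather the relevant versions in one go. It is a sketch, not an official tool: it assumes the pip distribution names `torch`, `transformers` and `flash-attn`, and it only queries CUDA information when torch is importable. (In the 4.x releases, transformers also provides a `transformers-cli env` command for the same purpose.)

```python
import importlib.metadata
import importlib.util
import platform


def package_version(name: str) -> str:
    """Return the installed version of a distribution, or 'not installed'."""
    try:
        return importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return "not installed"


def collect_setup() -> dict:
    """Collect the versions that matter for this issue."""
    info = {
        "python": platform.python_version(),
        "torch": package_version("torch"),
        "transformers": package_version("transformers"),
        "flash-attn": package_version("flash-attn"),
    }
    # Only query CUDA details if torch can actually be imported.
    if importlib.util.find_spec("torch") is not None:
        import torch

        info["cuda_available"] = torch.cuda.is_available()
        info["torch_cuda"] = torch.version.cuda  # None on CPU-only builds
    return info


if __name__ == "__main__":
    for key, value in collect_setup().items():
        print(f"{key}: {value}")
```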

It seems related to flash-attn: without flash-attn, the model produces NaNs. It also seems related to this issue: https://huggingface.co/answerdotai/ModernBERT-base/discussions/43
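One common way padding turns into NaNs inside attention is shown below as an illustration only; it is not the confirmed root cause of this issue. If a row of attention scores is masked entirely with `-inf`, softmax has nothing finite to normalize and returns NaN for the whole row. Because attention only mixes tokens within a sequence, such a NaN would contaminate that sequence's hidden states and logits (and the mean batch loss) while leaving the other rows intact, which is the pattern in the output above. Masking with the most negative finite value instead keeps the row finite.

```python
import torch

# Attention scores for one query over 4 key positions (arbitrary values).
scores = torch.tensor([0.3, -1.2, 0.8, 0.1])

# Normal case: the last two keys are padding and masked with -inf.
# Real keys share the probability mass, padded keys get exactly 0.
partial_mask = torch.tensor([False, False, True, True])
probs_partial = torch.softmax(scores.masked_fill(partial_mask, float("-inf")), dim=-1)

# Degenerate case: every key is masked with -inf. The row maximum is -inf,
# so softmax computes (-inf) - (-inf) = NaN and the whole row becomes NaN.
full_mask = torch.ones(4, dtype=torch.bool)
probs_inf = torch.softmax(scores.masked_fill(full_mask, float("-inf")), dim=-1)

# Same degenerate row, masked with the most negative finite float instead:
# all entries are equal, so softmax returns uniform, finite weights.
finite_min = torch.finfo(scores.dtype).min
probs_finite = torch.softmax(scores.masked_fill(full_mask, finite_min), dim=-1)

print("partial mask:", probs_partial)
print("all -inf:    ", probs_inf)
print("all finfo.min:", probs_finite)
```

Different attention backends (eager, SDPA kernels, flash-attn) handle masks differently, which would be consistent with the NaNs depending on the attention implementation and library versions.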

A fresh install of torch (previously 2.3.0, now 2.5.1) and the latest version of transformers (installed directly from git, as @staghado suggested) fixed the issue.
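Since upgrading fixed the problem here, a quick guard like the one below can flag environments still on an older torch. It is a sketch under stated assumptions: 2.5.1 is simply the version reported to work in this thread, not a verified minimum, and the check covers torch only, not the transformers-from-git part of the fix.

```python
import importlib.metadata
import re

# Version reported to work in this thread (2.3.0 showed the NaNs).
# NOT a verified minimum, just the known-good version from the report.
KNOWN_GOOD_TORCH = (2, 5, 1)


def release_tuple(version: str) -> tuple:
    """Turn e.g. '2.5.1+cu121' or '2.6.0.dev20241112' into (major, minor, patch)."""
    parts = []
    for piece in version.split("+")[0].split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def torch_at_least_known_good() -> bool:
    """True if the installed torch is at least the version that fixed the issue here."""
    try:
        installed = importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
        return False
    return release_tuple(installed) >= KNOWN_GOOD_TORCH


print("torch >= 2.5.1:", torch_at_least_known_good())
```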
