Problem with highly padded sequences
Hi all,
I am trying to fine-tune a ModernBERT model for a text classification task, but I get NaN logits when the input is highly padded. Example (inputs already tokenized):
import torch
from transformers import ModernBertForSequenceClassification
# Load the model (the inputs below are already tokenized, so no tokenizer is needed)
model_name = "answerdotai/ModernBERT-base"
model = ModernBertForSequenceClassification.from_pretrained(model_name, num_labels=10, reference_compile=False)
# Provided input_ids tensor with all sequences
input_ids = torch.tensor([
[50281, 28782, 2380, 281, 776, 10309, 3237, 21726, 50282, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283],
[50281, 53, 2555, 369, 4895, 275, 1072, 1617, 352, 369,
2197, 275, 285, 1335, 2509, 253, 1072, 2181, 352, 369,
2197, 275, 323, 21726, 50282, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283],
[50281, 42, 1904, 626, 4763, 667, 5511, 326, 352, 369,
1146, 4895, 1919, 846, 352, 2692, 598, 3066, 46690, 387,
776, 3906, 15, 21726, 50282, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283],
[50281, 47638, 2, 388, 3271, 12278, 14, 5683, 50282, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
50283, 50283, 50283, 50283, 50283]
], device="cuda")
# Dynamically create attention masks
attention_mask = (input_ids != 50283).long()
# Define labels
labels = torch.tensor([0, 1, 2, 3], device="cuda")
# Move the model to the same device as the inputs (the tensors above are created on "cuda")
device = input_ids.device
model = model.to(device)
# Prepare model inputs
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "labels": labels,
}
# Perform a forward pass
outputs = model(**inputs)
# Output the logits
print("Logits:", outputs.logits)
print("Logits shape:", outputs.logits.shape)
# Check for NaN values
nan_indices = torch.isnan(outputs.logits).any(dim=1).nonzero(as_tuple=True)[0]
if nan_indices.numel() > 0:
    print(f"NaN detected in logits for batch indices: {nan_indices.tolist()}")
# Print the loss if labels are provided
if outputs.loss is not None:
    print("Loss:", outputs.loss)
This produces the following output:
Logits: tensor([[ nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan],
[-0.8703, -0.6405, -0.6561, -0.2938, -1.2773, -0.1523, 0.9613, 1.4057,
-0.5483, 0.0425],
[-0.5069, -0.5104, -0.6058, -0.6675, -0.5392, 0.3078, 0.5109, 0.7993,
-0.0116, -0.1695],
[ nan, nan, nan, nan, nan, nan, nan, nan,
nan, nan]], device='cuda:0', grad_fn=<AddmmBackward0>)
Logits shape: torch.Size([4, 10])
NaN detected in logits for batch indices: [0, 3]
Loss: tensor(nan, device='cuda:0', grad_fn=<NllLossBackward0>)
As you can see, inputs 0 and 3 produce NaN logits, apparently because of the padding (token ID 50283). Any guesses why? Thanks a lot for helping!
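One hypothesis that fits the output above, though nothing in this thread confirms it: ModernBERT uses sliding-window (local) attention in some layers, and some attention kernels are known to return NaN for a query position when every key it can attend to is masked. The sketch below is pure Python and only checks which rows contain a position with no real token inside its window. The helper name `rows_with_fully_masked_windows` is hypothetical, and the half-window of 64 tokens is an assumption about the model config, not something stated in the post.

```python
def rows_with_fully_masked_windows(attention_mask, half_window=64):
    """Return indices of rows where at least one position has no real
    (unmasked) token within +/- half_window positions.

    half_window=64 is an assumed value for illustration only.
    """
    flagged = []
    for row_idx, row in enumerate(attention_mask):
        real_positions = [i for i, m in enumerate(row) if m == 1]
        for pos in range(len(row)):
            # A position is "isolated" if every key in its window is padding.
            if not any(abs(pos - r) <= half_window for r in real_positions):
                flagged.append(row_idx)
                break
    return flagged


# Masks matching the example batch: 75 tokens per row,
# with 9, 25, 25 and 9 real tokens followed by padding.
seq_len = 75
real_lengths = [9, 25, 25, 9]
masks = [[1] * n + [0] * (seq_len - n) for n in real_lengths]

flagged = rows_with_fully_masked_windows(masks)
print(flagged)  # [0, 3]
```

Under these assumptions the flagged rows are exactly the ones that returned NaN, which suggests the padding length relative to the window matters more than the padding token itself.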
Facing the same issue (transformers 4.48.0) and not using flash-attn.
Could you give more details about your setup?
I can't reproduce it on the latest version of transformers with SDPA, installed with: pip install git+https://github.com/huggingface/transformers.git
It seems related to flash-attn: without flash-attn, the model produces NaNs. It also seems related to this issue: https://huggingface.co/answerdotai/ModernBERT-base/discussions/43
A fresh install of torch (previously 2.3.0, now 2.5.1) and the latest version of transformers (installed directly from git, as @staghado suggested) fixed the issue.
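Since the fix came down to package versions, here is a minimal sketch for reporting the relevant versions when describing a setup. It uses only the standard library, so it works even if one of the packages is missing. The helper name `installed_versions` is hypothetical, and the distribution names (in particular `flash-attn`) are assumed to match how the packages were installed with pip.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("torch", "transformers", "flash-attn")):
    """Map each distribution name to its installed version, or None if absent."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


for name, ver in installed_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```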