metadata
library_name: transformers
language:
  - multilingual
  - bn
  - cs
  - de
  - en
  - et
  - fi
  - fr
  - gu
  - ha
  - hi
  - is
  - ja
  - kk
  - km
  - lt
  - lv
  - pl
  - ps
  - ru
  - ta
  - tr
  - uk
  - xh
  - zh
  - zu
license: apache-2.0
base_model: answerdotai/ModernBERT-base
tags:
  - quality-estimation
  - regression
  - generated_from_trainer
datasets:
  - ymoslem/wmt-da-human-evaluation-long-context
model-index:
  - name: Quality Estimation for Machine Translation
    results:
      - task:
          type: regression
        dataset:
          name: ymoslem/wmt-da-human-evaluation-long-context
          type: QE
        metrics:
          - name: Pearson Correlation
            type: Pearson
            value: 0.5013
          - name: Mean Absolute Error
            type: MAE
            value: 0.1024
          - name: Root Mean Squared Error
            type: RMSE
            value: 0.1464
          - name: R-Squared
            type: R2
            value: 0.251
metrics:
  - pearsonr
  - mae
  - r_squared

Quality Estimation for Machine Translation

This model is a fine-tuned version of answerdotai/ModernBERT-base on the ymoslem/wmt-da-human-evaluation-long-context dataset. It achieves the following results on the evaluation set:

  • Loss: 0.0214
  • Pearson: 0.5013
  • MAE: 0.1024
  • RMSE: 0.1464
  • R2: 0.251

Model description

This model is for reference-free, long-context quality estimation (QE) of machine translation (MT) output. It is trained on a dataset of translation pairs in which each side comprises up to 32 sentences (i.e. up to 64 sentences per source-target pair). Hence, the model is suitable for document-level quality estimation.

Training and evaluation data

The model is trained on the long-context dataset ymoslem/wmt-da-human-evaluation-long-context. This document-level Quality Estimation dataset is an augmented variant of the sentence-level WMT DA Human Evaluation dataset. In addition to individual sentences, it contains concatenations of 2, 4, 8, 16, and 32 sentences within each language pair lp and domain. For a concatenated example, the raw column is a weighted average of the sentence-level scores, using the character lengths of src and mt as weights (see the sketch after the data statistics below).

  • Training data: 7.65 million long-context texts
  • Test data: 59,235 long-context texts
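
To illustrate the weighting, here is a minimal sketch (an assumption, not the exact dataset-construction code) of how sentence-level scores can be combined into a single document-level raw score, with each sentence weighted by the character lengths of its src and mt:

# Minimal sketch of the score aggregation described above (an assumption,
# not the exact dataset-construction code): each sentence is weighted by
# the combined character length of its source (src) and translation (mt).
def document_score(src_sentences, mt_sentences, sentence_scores):
    weights = [len(s) + len(t) for s, t in zip(src_sentences, mt_sentences)]
    weighted_sum = sum(w * score for w, score in zip(weights, sentence_scores))
    return weighted_sum / sum(weights)

# Example: a 2-sentence "document" with sentence-level scores 0.8 and 0.6
print(document_score(["Hello world.", "How are you today?"],
                     ["Hallo Welt.", "Wie geht es dir heute?"],
                     [0.8, 0.6]))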

Training procedure

The model was trained on a single H200 SXM GPU (143 GB VRAM) for approximately 26 hours, with the following settings (see the setup sketch after this list):

  • tokenizer.model_max_length: 8192 (full context length)
  • attn_implementation: flash_attention_2
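
These settings can be applied when loading the tokenizer and the base model for fine-tuning, roughly as in the following sketch (assuming a single-label regression head and bfloat16 precision, which flash_attention_2 requires; this is not the exact training script):

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

base_model = "answerdotai/ModernBERT-base"

# Use the full 8192-token context window
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.model_max_length = 8192

# num_labels=1 yields a single-output (regression) head
model = AutoModelForSequenceClassification.from_pretrained(
    base_model,
    num_labels=1,
    torch_dtype=torch.bfloat16,              # assumption: bf16, as flash attention requires half precision
    attn_implementation="flash_attention_2",
)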

Training hyperparameters

The following hyperparameters were used during training (see the TrainingArguments sketch after the list):

  • learning_rate: 0.0003
  • train_batch_size: 128
  • eval_batch_size: 128
  • seed: 42
  • optimizer: adamw_torch_fused with betas=(0.9, 0.999) and epsilon=1e-08 (no additional optimizer arguments)
  • lr_scheduler_type: linear
  • training_steps: 60000 (approx. 1 epoch)
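
These values roughly correspond to the following TrainingArguments (a hedged sketch of the configuration, not the exact training script; the output directory name is hypothetical):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="modernbert-long-context-qe",  # hypothetical output directory
    learning_rate=3e-4,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    seed=42,
    optim="adamw_torch_fused",
    lr_scheduler_type="linear",
    max_steps=60000,       # approx. one epoch
    eval_strategy="steps",
    eval_steps=1000,       # matches the validation points in the table below
    logging_steps=1000,
    bf16=True,             # assumption, to match flash_attention_2
)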

Training results

Training Loss Epoch Step Validation Loss
0.0233 0.0167 1000 0.0233
0.0232 0.0335 2000 0.0230
0.0225 0.0502 3000 0.0230
0.023 0.0669 4000 0.0224
0.0226 0.0837 5000 0.0223
0.0226 0.1004 6000 0.0225
0.0219 0.1171 7000 0.0222
0.022 0.1339 8000 0.0222
0.0213 0.1506 9000 0.0221
0.0213 0.1673 10000 0.0220
0.0218 0.1840 11000 0.0219
0.0215 0.2008 12000 0.0225
0.0218 0.2175 13000 0.0219
0.0218 0.2342 14000 0.0218
0.0217 0.2510 15000 0.0219
0.0219 0.2677 16000 0.0219
0.0212 0.2844 17000 0.0219
0.0219 0.3012 18000 0.0219
0.0218 0.3179 19000 0.0219
0.0213 0.3346 20000 0.0217
0.0218 0.3514 21000 0.0217
0.021 0.3681 22000 0.0217
0.0219 0.3848 23000 0.0220
0.0211 0.4016 24000 0.0216
0.0211 0.4183 25000 0.0216
0.0206 0.4350 26000 0.0216
0.021 0.4517 27000 0.0215
0.0214 0.4685 28000 0.0215
0.0214 0.4852 29000 0.0216
0.0204 0.5019 30000 0.0216
0.022 0.5187 31000 0.0216
0.0212 0.5354 32000 0.0217
0.0211 0.5521 33000 0.0216
0.0208 0.5689 34000 0.0215
0.0208 0.5856 35000 0.0215
0.0215 0.6023 36000 0.0215
0.0212 0.6191 37000 0.0215
0.0213 0.6358 38000 0.0215
0.0211 0.6525 39000 0.0215
0.0208 0.6693 40000 0.0215
0.0205 0.6860 41000 0.0215
0.0209 0.7027 42000 0.0215
0.021 0.7194 43000 0.0215
0.0207 0.7362 44000 0.0215
0.0197 0.7529 45000 0.0215
0.0211 0.7696 46000 0.0214
0.021 0.7864 47000 0.0215
0.0207 0.8031 48000 0.0214
0.0219 0.8198 49000 0.0215
0.0208 0.8366 50000 0.0215
0.0202 0.8533 51000 0.0215
0.02 0.8700 52000 0.0215
0.0205 0.8868 53000 0.0214
0.0214 0.9035 54000 0.0215
0.0205 0.9202 55000 0.0214
0.0209 0.9370 56000 0.0214
0.0206 0.9537 57000 0.0214
0.0204 0.9704 58000 0.0214
0.0203 0.9872 59000 0.0214
0.0209 1.0039 60000 0.0214

Framework versions

  • Transformers 4.48.1
  • Pytorch 2.4.1+cu124
  • Datasets 3.2.0
  • Tokenizers 0.21.0

Inference

  1. Install the required libraries.
pip3 install --upgrade datasets accelerate transformers
pip3 install --upgrade flash_attn triton
  2. Load the test dataset.
from datasets import load_dataset

test_dataset = load_dataset("ymoslem/wmt-da-human-evaluation",
                             split="test",
                             trust_remote_code=True
                            )
print(test_dataset)
  3. Load the model and tokenizer:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load the fine-tuned model and tokenizer
model_name = "ymoslem/ModernBERT-base-long-context-qe-v1"
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
  4. Prepare the dataset. Each source segment src and its machine translation mt are separated by the tokenizer's sep_token:
sep_token = tokenizer.sep_token
input_test_texts = [f"{src} {sep_token} {tgt}" for src, tgt in zip(test_dataset["src"], test_dataset["mt"])]
  5. Generate predictions.

If you print model.config.problem_type, the output is "regression". Still, you can use the "text-classification" pipeline as follows (cf. the pipeline documentation):

from transformers import pipeline

classifier = pipeline("text-classification",
                      model=model_name,
                      tokenizer=tokenizer,
                      device=0,
                     )

predictions = classifier(input_test_texts,
                         batch_size=128,
                         truncation=True,
                         padding="max_length",
                         max_length=tokenizer.model_max_length,
                       )
predictions = [prediction["score"] for prediction in predictions]

Alternatively, you can use a more elaborate version of the code, which is slightly faster and provides more control:

from torch.utils.data import DataLoader
import torch
from tqdm.auto import tqdm

# Tokenization function
def process_batch(batch, tokenizer, device):
    sep_token = tokenizer.sep_token
    input_texts = [f"{src} {sep_token} {tgt}" for src, tgt in zip(batch["src"], batch["mt"])]
    tokens = tokenizer(input_texts,
                       truncation=True,
                       padding="max_length",
                       max_length=tokenizer.model_max_length,
                       return_tensors="pt",
                      ).to(device)
    return tokens

# Create a DataLoader for batching
test_dataloader = DataLoader(test_dataset, 
                             batch_size=128,   # Adjust batch size as needed
                             shuffle=False)


# List to store all predictions
predictions = []

with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Inference Progress", unit="batch"):

        tokens = process_batch(batch, tokenizer, device)
        
        # Forward pass: Generate model's logits
        outputs = model(**tokens)

        # Get logits (predictions)
        logits = outputs.logits

        # Extract the regression predictions (one value per example)
        batch_predictions = logits.squeeze(-1)

        # Extend the list with the predictions
        predictions.extend(batch_predictions.tolist())
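
To compare against the metrics reported above, you can then score the predictions against the gold labels (a sketch assuming recent versions of scipy and scikit-learn, and that the gold labels are in the dataset's raw column on the same scale as the model outputs):

import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Gold scores: the column name ("raw") follows the dataset description above;
# rescale them (e.g. divide by 100) if they are not on the same scale as the
# model outputs.
references = np.array(test_dataset["raw"], dtype=float)
hypotheses = np.array(predictions, dtype=float)

pearson = pearsonr(references, hypotheses).statistic
mae = mean_absolute_error(references, hypotheses)
rmse = np.sqrt(mean_squared_error(references, hypotheses))
r2 = r2_score(references, hypotheses)

print(f"Pearson: {pearson:.4f}  MAE: {mae:.4f}  RMSE: {rmse:.4f}  R2: {r2:.4f}")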