Batch inference yields no performance boost

#1
by Dregeary - opened

Hello there. Great performance on the RAID Leaderboard!
I'm trying to test the model on my end. Inference on a single text works fine, but for some reason batch inference yields no performance gain at all on GPU.

Using the same code as in the README, but passing a list of texts to the tokenizer (encoded = tokenizer(texts, padding='max_length', ...)), here are the numbers I got (a short note on the padding mode follows the benchmark results below):

Total time = 48.1247 s (nb_texts = 200 | batch size = 2 sentences)
Total time = 47.8157 s (nb_texts = 200 | batch size = 5 sentences)
Total time = 47.5367 s (nb_texts = 200 | batch size = 10 sentences)
Total time = 47.7280 s (nb_texts = 200 | batch size = 15 sentences)
Total time = 48.2357 s (nb_texts = 200 | batch size = 30 sentences)

Using GPU T4x2

Total time = 33.9862 s (nb_texts = 200 | batch size = 2 sentences)
Total time = 23.4638 s (nb_texts = 200 | batch size = 5 sentences)
Total time = 15.3777 s (nb_texts = 200 | batch size = 10 sentences)
Total time = 15.2225 s (nb_texts = 200 | batch size = 15 sentences)
Total time = 13.3793 s (nb_texts = 200 | batch size = 30 sentences)

Using GPU P100

Total time = 31.0689 s (nb_texts = 200 | batch size = 2 sentences)
Total time = 28.6267 s (nb_texts = 200 | batch size = 5 sentences)
Total time = 27.5859 s (nb_texts = 200 | batch size = 10 sentences)
Total time = 27.4143 s (nb_texts = 200 | batch size = 15 sentences)
Total time = 27.1724 s (nb_texts = 200 | batch size = 30 sentences)
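A note on the padding mode, sketched below (the two example strings are placeholders; the model path is the one from this thread): padding='max_length' pads every batch to the full max_len tokens, so even a small batch carries the full 768-token compute cost and may already saturate the GPU, which could make larger batch sizes show little or no extra speedup. padding=True (dynamic padding) pads only to the longest text in each batch.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("desklib/ai-text-detector-v1.01")
texts = ["Example text one.", "A somewhat longer example text number two."]

# Pads every sequence to 768 tokens, regardless of the actual text length
fixed = tokenizer(texts, padding='max_length', truncation=True,
                  max_length=768, return_tensors='pt')

# Pads only to the longest sequence in this batch (dynamic padding)
dynamic = tokenizer(texts, padding=True, truncation=True,
                    max_length=768, return_tensors='pt')

print(fixed['input_ids'].shape)    # torch.Size([2, 768])
print(dynamic['input_ids'].shape)  # [2, length of the longest tokenized text]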

Code

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, AutoModel, PreTrainedModel
import time
from typing import List, Tuple

class DesklibAIDetectionModel(PreTrainedModel):
    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = AutoModel.from_config(config)
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Forward pass through the transformer
        outputs = self.model(input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs[0]
        
        # Masked mean pooling over the token embeddings
        mask_expanded = attention_mask.unsqueeze(-1)
        sum_embeddings = torch.sum(last_hidden_state * mask_expanded, dim=1)
        sum_mask = mask_expanded.sum(dim=1).clamp(min=1e-9)
        pooled_output = sum_embeddings / sum_mask

        # Classifier
        logits = self.classifier(pooled_output)
        
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1), labels.float())
            return {"logits": logits, "loss": loss}
        
        return {"logits": logits}

def predict_batch(texts: List[str], 
                  model: DesklibAIDetectionModel, 
                  tokenizer: AutoTokenizer, 
                  device: torch.device,
                  batch_size: int = 32,
                  max_len: int = 768,
                  threshold: float = 0.5) -> List[Tuple[float, int]]:
    
    results = []
    
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        
        # Tokenize with padding and truncation
        encoded = tokenizer(
            batch_texts,
            padding='max_length',
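            # NOTE: 'max_length' pads every batch to max_len tokens; padding=True pads only to the longest text in the batch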
            truncation=True,
            max_length=max_len,
            return_tensors='pt'
        )
        
        # Efficiently move tensors to GPU
        input_ids = encoded['input_ids'].to(device, non_blocking=True)
        attention_mask = encoded['attention_mask'].to(device, non_blocking=True)
        
        # Inference under inference_mode, with mixed precision enabled only when running on CUDA
        with torch.inference_mode(), torch.autocast(device_type=device.type, enabled=(device.type == "cuda")):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = torch.sigmoid(outputs["logits"]).cpu().numpy().flatten()
            
        batch_results = [(float(prob), 1 if prob >= threshold else 0) 
                         for prob in probabilities]
        results.extend(batch_results)
    
    return results

def main():
    model_path = "desklib/ai-text-detector-v1.01"

    sample_text = "Artificial Intelligence is transforming the way businesses operate, including in education. AI tools and platforms are overhauling traditional teaching methods to become more accessible, personalized, and smart. However, like everything in life, the benefits of AI come with their own set of challenges—especially around fairness and equity, which educators must address if they want to fully integrate technology into the learning experience."
    
    # Example texts (replace with your actual texts)
    texts = [sample_text for _ in range(200)]
    
    # Test different batch sizes
    batch_sizes = [2, 5, 10, 15, 30]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = DesklibAIDetectionModel.from_pretrained(model_path)
    
    # Enable multi-GPU if available
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs")
        model = nn.DataParallel(model)
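        # nn.DataParallel splits each input batch along dim 0 across the GPUs, so very small batch sizes leave the extra GPU(s) mostly idle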
        
    model.to(device)
    model.eval()

    # Option 1: check the compute capability before compiling (torch.compile's default
    # Inductor/Triton backend targets CUDA GPUs with compute capability >= 7.0)
    if hasattr(torch, "compile") and device.type == "cuda":
        device_capability = torch.cuda.get_device_capability(device)
        if device_capability[0] >= 7:
            model = torch.compile(model)
        else:
            print(f"Skipping torch.compile: GPU compute capability is {device_capability} (< (7, 0)).")
    
    # Option 2: Alternatively, you can suppress errors to fall back to eager mode
    # import torch._dynamo
    # torch._dynamo.config.suppress_errors = True
    # if hasattr(torch, "compile"):
    #     model = torch.compile(model)

    print("Is Fast Tokenizer: ", tokenizer.is_fast)
    
    results_benchmark = {}
    for batch_size in batch_sizes:
        if device.type == "cuda":
            torch.cuda.synchronize()
        start_time = time.perf_counter()

        _ = predict_batch(texts, model, tokenizer, device, batch_size=batch_size)

        if device.type == "cuda":
            torch.cuda.synchronize()
        end_time = time.perf_counter()
        total_time = end_time - start_time
        results_benchmark[batch_size] = total_time
        print(f"Batch size: {batch_size}, Time: {total_time:.4f} s")
    
    for batch_size, total_time in results_benchmark.items():
        print(f"Total time = {total_time:.4f} s (nb_texts = {len(texts)} | batch size = {batch_size} sentences)")

if __name__ == "__main__":
    main()
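
A small, separate sketch (the name split_timing is just for illustration; it assumes the imports and the model, tokenizer and device objects from the script above) to check whether the time is being spent in CPU-side tokenization or in the model forward pass:

def split_timing(texts, model, tokenizer, device, max_len=768):
    # Time the CPU-side tokenization separately from the model forward pass
    t0 = time.perf_counter()
    encoded = tokenizer(texts, padding=True, truncation=True,
                        max_length=max_len, return_tensors='pt')
    tok_time = time.perf_counter() - t0

    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    t0 = time.perf_counter()
    with torch.inference_mode():
        _ = model(input_ids=input_ids, attention_mask=attention_mask)
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for the GPU to finish before stopping the clock
    fwd_time = time.perf_counter() - t0

    print(f"tokenize: {tok_time:.4f} s | forward: {fwd_time:.4f} s")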
desklib changed discussion status to closed
