{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "0f9f666f", "metadata": {}, "outputs": [], "source": [ "from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments\n", "from datasets import load_dataset\n", "import torch\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import accuracy_score, precision_recall_fscore_support" ] }, { "cell_type": "code", "execution_count": 3, "id": "2f35116b", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a3bdffef37cd4d5aaa090640d5384825", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/25000 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Load the IMDb dataset\n", "dataset = load_dataset(\"imdb\")\n", "\n", "# Tokenizer function\n", "tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')\n", "\n", "def tokenize_function(examples):\n", " return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True, max_length=512)\n", "\n", "# Tokenize the dataset\n", "tokenized_datasets = dataset.map(tokenize_function, batched=True)\n", "\n", "# Format for PyTorch\n", "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=42).select(range(10000)) # Subset for training\n", "test_dataset = tokenized_datasets[\"test\"].shuffle(seed=42).select(range(1000)) # Subset for testing\n", "\n", "train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])\n", "test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "93d6a61b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "93d6a61b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight']\n",
      "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'classifier.bias', 'pre_classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# Load the pretrained encoder with a freshly initialized two-class head\n",
    "model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "58400de8",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',\n",
    "    num_train_epochs=3,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=64,\n",
    "    warmup_steps=500,              # linear LR warmup before the schedule decays\n",
    "    weight_decay=0.01,\n",
    "    logging_dir='./logs',\n",
    "    evaluation_strategy='steps',   # eval_steps defaults to logging_steps, so eval runs every 50 steps\n",
    "    save_strategy='steps',         # must match evaluation_strategy when load_best_model_at_end=True\n",
    "    load_best_model_at_end=True,   # restore the checkpoint with the lowest eval loss after training\n",
    "    logging_steps=50,\n",
    "    save_steps=50                  # 1875 total steps / 50 gives ~37 checkpoints unless save_total_limit is set\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3389ad91",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(pred):\n",
    "    labels = pred.label_ids\n",
    "    preds = pred.predictions.argmax(-1)\n",
    "    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')\n",
    "    acc = accuracy_score(labels, preds)\n",
    "    return {\n",
    "        'accuracy': acc,\n",
    "        'f1': f1,\n",
    "        'precision': precision,\n",
    "        'recall': recall\n",
    "    }\n"
   ]
  },
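  {
   "cell_type": "markdown",
   "id": "b7e4f2a0",
   "metadata": {},
   "source": [
    "`compute_metrics` can be exercised on hand-made predictions before it is wired into the `Trainer`. The cell below is a minimal sketch using `transformers.EvalPrediction`; the four-example logits and labels are illustrative, chosen so the expected output is accuracy 0.75, precision 1.0, recall ≈ 0.667, F1 0.8.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e4f2a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: run compute_metrics on dummy logits/labels to check the\n",
    "# returned dict before training. The values below are illustrative only.\n",
    "import numpy as np\n",
    "from transformers import EvalPrediction\n",
    "\n",
    "dummy = EvalPrediction(\n",
    "    predictions=np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]),  # fake logits\n",
    "    label_ids=np.array([1, 0, 1, 1]),\n",
    ")\n",
    "print(compute_metrics(dummy))\n"
   ]
  },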
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0d68d5ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`, you can safely ignore this message.\n",
      "C:\\Users\\saime\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\transformers\\optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n",
      "***** Running training *****\n",
      "  Num examples = 10000\n",
      "  Num Epochs = 3\n",
      "  Instantaneous batch size per device = 16\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 1875\n",
      "  Number of trainable parameters = 66955010\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr><th>Step</th><th>Training Loss</th><th>Validation Loss</th><th>Accuracy</th><th>F1</th><th>Precision</th><th>Recall</th></tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr><td>50</td><td>0.688800</td><td>0.680938</td><td>0.661000</td><td>0.543742</td><td>0.792157</td><td>0.413934</td></tr>\n",
       "    <tr><td>100</td><td>0.629000</td><td>0.465259</td><td>0.841000</td><td>0.819113</td><td>0.920716</td><td>0.737705</td></tr>\n",
       "    <tr><td>150</td><td>0.371200</td><td>0.323407</td><td>0.868000</td><td>0.867470</td><td>0.850394</td><td>0.885246</td></tr>\n",
       "    <tr><td>200</td><td>0.336300</td><td>0.374150</td><td>0.857000</td><td>0.836197</td><td>0.948052</td><td>0.747951</td></tr>\n",
       "    <tr><td>250</td><td>0.336700</td><td>0.312763</td><td>0.865000</td><td>0.871795</td><td>0.812389</td><td>0.940574</td></tr>\n",
       "    <tr><td>300</td><td>0.311800</td><td>0.296506</td><td>0.889000</td><td>0.882540</td><td>0.912473</td><td>0.854508</td></tr>\n",
       "    <tr><td>350</td><td>0.309800</td><td>0.286319</td><td>0.886000</td><td>0.886228</td><td>0.863813</td><td>0.909836</td></tr>\n",
       "    <tr><td>400</td><td>0.272300</td><td>0.292773</td><td>0.890000</td><td>0.884696</td><td>0.905579</td><td>0.864754</td></tr>\n",
       "    <tr><td>450</td><td>0.315100</td><td>0.419856</td><td>0.854000</td><td>0.831019</td><td>0.954787</td><td>0.735656</td></tr>\n",
       "    <tr><td>500</td><td>0.350900</td><td>0.298303</td><td>0.862000</td><td>0.869565</td><td>0.807018</td><td>0.942623</td></tr>\n",
       "    <tr><td>550</td><td>0.355200</td><td>0.333094</td><td>0.870000</td><td>0.852608</td><td>0.954315</td><td>0.770492</td></tr>\n",
       "    <tr><td>600</td><td>0.279900</td><td>0.282081</td><td>0.887000</td><td>0.879915</td><td>0.913907</td><td>0.848361</td></tr>\n",
       "    <tr><td>650</td><td>0.279200</td><td>0.288312</td><td>0.892000</td><td>0.883621</td><td>0.931818</td><td>0.840164</td></tr>\n",
       "    <tr><td>700</td><td>0.198600</td><td>0.338301</td><td>0.876000</td><td>0.863736</td><td>0.931280</td><td>0.805328</td></tr>\n",
       "    <tr><td>750</td><td>0.195600</td><td>0.292916</td><td>0.897000</td><td>0.897512</td><td>0.872340</td><td>0.924180</td></tr>\n",
       "    <tr><td>800</td><td>0.243400</td><td>0.289307</td><td>0.899000</td><td>0.900883</td><td>0.864407</td><td>0.940574</td></tr>\n",
       "    <tr><td>850</td><td>0.193000</td><td>0.304464</td><td>0.897000</td><td>0.894359</td><td>0.895277</td><td>0.893443</td></tr>\n",
       "    <tr><td>900</td><td>0.214500</td><td>0.257609</td><td>0.899000</td><td>0.895337</td><td>0.905660</td><td>0.885246</td></tr>\n",
       "    <tr><td>950</td><td>0.228000</td><td>0.279465</td><td>0.887000</td><td>0.891659</td><td>0.837838</td><td>0.952869</td></tr>\n",
       "    <tr><td>1000</td><td>0.208100</td><td>0.230380</td><td>0.910000</td><td>0.908537</td><td>0.901210</td><td>0.915984</td></tr>\n",
       "    <tr><td>1050</td><td>0.200600</td><td>0.307765</td><td>0.901000</td><td>0.902077</td><td>0.871893</td><td>0.934426</td></tr>\n",
       "    <tr><td>1100</td><td>0.210600</td><td>0.278725</td><td>0.901000</td><td>0.901493</td><td>0.876209</td><td>0.928279</td></tr>\n",
       "    <tr><td>1150</td><td>0.208200</td><td>0.283095</td><td>0.912000</td><td>0.909836</td><td>0.909836</td><td>0.909836</td></tr>\n",
       "    <tr><td>1200</td><td>0.201000</td><td>0.256353</td><td>0.901000</td><td>0.895238</td><td>0.925602</td><td>0.866803</td></tr>\n",
       "    <tr><td>1250</td><td>0.186200</td><td>0.249205</td><td>0.909000</td><td>0.906282</td><td>0.910973</td><td>0.901639</td></tr>\n",
       "    <tr><td>1300</td><td>0.080400</td><td>0.367344</td><td>0.902000</td><td>0.900609</td><td>0.891566</td><td>0.909836</td></tr>\n",
       "    <tr><td>1350</td><td>0.152700</td><td>0.323376</td><td>0.905000</td><td>0.900315</td><td>0.922581</td><td>0.879098</td></tr>\n",
       "    <tr><td>1400</td><td>0.100400</td><td>0.416915</td><td>0.888000</td><td>0.891892</td><td>0.843066</td><td>0.946721</td></tr>\n",
       "    <tr><td>1450</td><td>0.108800</td><td>0.324885</td><td>0.908000</td><td>0.907258</td><td>0.892857</td><td>0.922131</td></tr>\n",
       "    <tr><td>1500</td><td>0.066700</td><td>0.378826</td><td>0.902000</td><td>0.901210</td><td>0.886905</td><td>0.915984</td></tr>\n",
       "    <tr><td>1550</td><td>0.078500</td><td>0.368980</td><td>0.906000</td><td>0.901674</td><td>0.920940</td><td>0.883197</td></tr>\n",
       "    <tr><td>1600</td><td>0.081500</td><td>0.364918</td><td>0.909000</td><td>0.907048</td><td>0.904277</td><td>0.909836</td></tr>\n",
       "    <tr><td>1650</td><td>0.062600</td><td>0.386855</td><td>0.905000</td><td>0.903943</td><td>0.892216</td><td>0.915984</td></tr>\n",
       "    <tr><td>1700</td><td>0.067000</td><td>0.392243</td><td>0.906000</td><td>0.905051</td><td>0.892430</td><td>0.918033</td></tr>\n",
       "    <tr><td>1750</td><td>0.047400</td><td>0.409893</td><td>0.910000</td><td>0.908350</td><td>0.902834</td><td>0.913934</td></tr>\n",
       "    <tr><td>1800</td><td>0.108200</td><td>0.401962</td><td>0.909000</td><td>0.907801</td><td>0.897796</td><td>0.918033</td></tr>\n",
       "    <tr><td>1850</td><td>0.105400</td><td>0.390589</td><td>0.912000</td><td>0.910020</td><td>0.908163</td><td>0.911885</td></tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
"text/plain": [
"