- Published on
Multi-Label Classification with Pytorch Lightning and Huggingface
- Author
- Name
- Sean Burt
- @SeanBurt8
This post covers fine-tuning DistilBERT for multi-label classification of comments. The dataset I used can be found on Kaggle. The labels in this dataset are: `toxic`, `severe_toxic`, `obscene`, `threat`, `insult`, and `identity_hate`.
The code for this was inspired by this notebook by Sebastian Raschka (@rasbt). It uses Lightning with PyTorch, and utilizes Hugging Face Transformers to load the base model and tokenizer for `distilbert-base-uncased`.
The following sections talk a bit about some of the additional things you need to change in Sebastian's notebook to make this work. The code is not perfect, but it worked pretty well and got decent results with little to no pre-processing of the comments.
- Make sure the inputs are strings, not other types (for example, missing values read in as `NaN` floats).
- Otherwise you will get the error shown below, which is raised when the input contains something that cannot be tokenized.
- I used
# Lowercase every comment.
# NOTE(review): pandas `.str` methods return NaN for non-string cells, so this
# line does not by itself convert values to strings; in this post it works
# because `.dropna()` removed the missing values first.
# `df["text"].astype(str).str.lower()` would guarantee strings — confirm.
df["text"] = df["text"].str.lower()
on the column, which was enough to fix this issue.
ValueError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]
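Additionally, I have the `is_split_into_words` flag set to `True` when loading the tokenizer from the pretrained model. (Note that `is_split_into_words` is normally passed when *calling* the tokenizer on input that is already split into words, so setting it in `from_pretrained` most likely has no effect here.)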
# Load the tokenizer that matches the distilbert-base-uncased checkpoint.
# NOTE(review): `is_split_into_words` is documented as an argument of the
# tokenizer *call*, for input that is already a list of words. Passed to
# from_pretrained it is likely ignored, and the comments here are raw strings,
# so it should not be needed — confirm.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", is_split_into_words=True)
When using `AutoModelForSequenceClassification.from_pretrained()` for multi-label classification, `problem_type` needs to be set to `multi_label_classification`. I found this solution in alexander's post after looking through the official documentation for multi-label classification; the official documentation had not been updated for this PR. Setting it changes the loss function to `BCEWithLogitsLoss()` instead of `CrossEntropyLoss()` or `MSELoss()` (which of those is used otherwise depends on the problem type).
# NOTE(review): the arguments of this call were lost from the post. Per the
# surrounding text it needs problem_type="multi_label_classification", plus
# num_labels=6 so the logits line up with the six stacked label columns.
model = AutoModelForSequenceClassification.from_pretrained(
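- Additionally, you will need to add the number of labels as a property of the Lightning model (`self.num_labels`), and pass `num_labels=6` to `from_pretrained()` so the classification head outputs one logit per label.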
If you are using torchmetrics for accuracy like I did, you will need to either pass `task="multilabel"` (together with `num_labels`) to `Accuracy`, or import `MultilabelAccuracy` directly and set `num_labels` to the desired number (6, for my problem).
# Inside LightningModel.__init__: separate multilabel accuracy metrics for the
# validation and test loops, sized for the 6 labels.
self.val_acc = MultilabelAccuracy(num_labels=6)
self.test_acc = MultilabelAccuracy(num_labels=6)
The following shows the code I used to fine-tune the model for multi-label sequence classification.
import os
import random
# pathing
from pathlib import Path
# For data manipulation
import numpy as np
import pandas as pd
# Pytorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# Lighting Imports
import pytorch_lightning as pt
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.core import LightningModule
# Finetune scheduler
from finetuning_scheduler import FinetuningScheduler
# Torchmetrics
from torchmetrics.classification import MultilabelAccuracy
Define the tokenizer and model. Notice that the `problem_type` flag is set in `from_pretrained`.
# NOTE(review): the arguments of this call were lost from the post. Per the
# surrounding text it needs problem_type="multi_label_classification", plus
# num_labels=6 so the logits line up with the six stacked label columns.
model = AutoModelForSequenceClassification.from_pretrained(
Define the dataset
class JigsawDataset(Dataset):
    """Map-style torch ``Dataset`` over one named split of a dataset dict.

    Args:
        dataset_dict: Mapping from split name to a split object that supports
            integer indexing and exposes a ``num_rows`` attribute.
        partition_key: Name of the split to wrap (default ``"train"``).
    """

    def __init__(self, dataset_dict, partition_key="train"):
        split = dataset_dict[partition_key]
        self.partition = split

    def __len__(self):
        # The length is delegated to the wrapped split's row count.
        return self.partition.num_rows

    def __getitem__(self, index):
        split = self.partition
        return split[index]
Define the Lightning model. It has 6 labels and shows some of the gotchas mentioned above regarding the `MultilabelAccuracy` metrics and `self.num_labels`.
class LightningModel(LightningModule):
    """LightningModule that fine-tunes a Hugging Face sequence-classification
    model for multi-label toxic-comment classification.

    Each batch must be a dict holding ``input_ids``, ``attention_mask`` and one
    0/1 entry per name in ``label_columns``.

    Args:
        model: Model whose forward accepts ``input_ids``, ``attention_mask`` and
            ``labels`` and returns a mapping with ``"loss"`` and ``"logits"``.
            For multi-label training it should be created with
            ``problem_type="multi_label_classification"`` and one output per
            label.
        learning_rate: Learning rate for the Adam optimizer.
    """

    # Label columns, in the order they are stacked into the target tensor.
    label_columns = (
        "toxic",
        "severe_toxic",
        "obscene",
        "threat",
        "insult",
        "identity_hate",
    )

    def __init__(self, model, learning_rate=5e-5):
        # LightningModule is an nn.Module: it must be initialised before any
        # sub-module (the wrapped model, the metrics) is assigned to self.
        super().__init__()
        self.num_labels = len(self.label_columns)
        self.learning_rate = learning_rate
        self.model = model
        self.val_acc = MultilabelAccuracy(num_labels=self.num_labels)
        self.test_acc = MultilabelAccuracy(num_labels=self.num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model. ``labels`` is optional so the module can
        also be called for inference."""
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)

    def _shared_step(self, batch):
        """Stack the label columns into a (batch, num_labels) target and run
        the model.

        Returns:
            ``(outputs, labels)``: the model output and the float targets.
        """
        labels = torch.stack(
            [batch[name] for name in self.label_columns], dim=1
        ).to(torch.float32)  # BCEWithLogitsLoss expects float targets
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=labels,
        )
        return outputs, labels

    def training_step(self, batch, batch_idx):
        outputs, _ = self._shared_step(batch)
        self.log("train_loss", outputs["loss"])
        return outputs["loss"]  # this is passed to the optimizer for training

    def validation_step(self, batch, batch_idx):
        outputs, labels = self._shared_step(batch)
        self.log("val_loss", outputs["loss"], prog_bar=True)
        # torchmetrics only auto-applies a sigmoid when some prediction lies
        # outside [0, 1]; converting explicitly makes the 0.5 threshold apply
        # to probabilities on every batch. Targets are passed as ints, per the
        # metric's documented input format.
        self.val_acc(torch.sigmoid(outputs["logits"]), labels.int())
        self.log("val_acc", self.val_acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        outputs, labels = self._shared_step(batch)
        # Same explicit sigmoid / int-target conversion as in validation.
        self.test_acc(torch.sigmoid(outputs["logits"]), labels.int())
        self.log("accuracy", self.test_acc, prog_bar=True)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
Open the Kaggle dataset's `train.csv` file and perform some steps to prepare it for our model.
# Define a variable to store the input and temp output path
input_path = Path("input/")
output_path = Path("output/")
# The split CSVs are written to output_path below, so make sure it exists.
output_path.mkdir(parents=True, exist_ok=True)

# Read the train csv file, drop rows with missing values, rename the
# comment_text column, and make sure every comment is a lowercase *string*.
# astype(str) matters: pandas `.str` methods turn any non-string cell into
# NaN instead of converting it, and the tokenizer rejects non-string input.
df = pd.read_csv(input_path / "train.csv").dropna()
df = df.rename(columns={"comment_text": "text"})
df["text"] = df["text"].astype(str).str.lower()

# Drop the ID field since we don't need it
df = df.drop(columns=["id"])

# Shuffle the dataframe (fixed seed, so the split is reproducible) and split
# the data into train, validation, and test. drop=True stops the old row
# index from being added as an extra "index" column in the saved CSVs.
df_shuffled = df.sample(frac=1, random_state=1).reset_index(drop=True)
df_train = df_shuffled.iloc[:100_000]
df_val = df_shuffled.iloc[100_000:140_000]
df_test = df_shuffled.iloc[140_000:]

# Save the new shuffled data into csv files to be opened later
df_train.to_csv(output_path / "train.csv", index=False, encoding="utf-8")
df_val.to_csv(output_path / "validation.csv", index=False, encoding="utf-8")
df_test.to_csv(output_path / "test.csv", index=False, encoding="utf-8")
Define the dataset using `load_dataset` from the Hugging Face `datasets` library.
from torch.utils.data import Dataset, DataLoader
# Load the three CSV splits written by the preparation step.
# NOTE(review): this snippet is incomplete as published. `load_dataset` is not
# imported anywhere in the post (presumably `from datasets import
# load_dataset`). The call's leading arguments and closing brackets were also
# lost, likely `load_dataset("csv", data_files={...})`. Confirm before running.
jigsaw_dataset = load_dataset(
"train": str(output_path / "train.csv"),
"validation": str(output_path / "validation.csv"),
"test": str(output_path / "test.csv"),
Tokenize the input text
# Tokenize each batch of comments.
# NOTE(review): the tokenizer call's arguments were lost from the post;
# presumably batch["text"] with truncation/padding enabled — confirm.
def tokenize_text(batch):
return tokenizer(
# batch_size=None makes `map` pass each whole split to tokenize_text as a
# single batch.
jigsaw_tokenized = jigsaw_dataset.map(tokenize_text, batched=True, batch_size=None)
# Return torch tensors for the model inputs and the six label columns only.
jigsaw_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"])
# Create the datasets
train_dataset = JigsawDataset(jigsaw_tokenized, partition_key="train")
val_dataset = JigsawDataset(jigsaw_tokenized, partition_key="validation")
test_dataset = JigsawDataset(jigsaw_tokenized, partition_key="test")
# Create the dataloaders
# NOTE(review): the following were all lost from the post:
#   - the DataLoader(...) arguments,
#   - the closing of `callbacks = [`,
#   - the Trainer(...) arguments,
#   - the trainer.fit(...) call.
# ckpt_path="best" below requires a prior trainer.fit run with the
# ModelCheckpoint callback.
train_loader = DataLoader(
val_loader = DataLoader(
test_loader = DataLoader(
lightning_model = LightningModel(model)
callbacks = [
ModelCheckpoint(save_top_k=1, mode="max", monitor="val_acc"), # save top 1 model
trainer = Trainer(
# Evaluate the best checkpoint (highest val_acc) on the train, val, and test loaders
trainer.test(lightning_model, dataloaders=train_loader, ckpt_path="best")
trainer.test(lightning_model, dataloaders=val_loader, ckpt_path="best")
trainer.test(lightning_model, dataloaders=test_loader, ckpt_path="best")
This outputs the following accuracies for the train, validation, and test sets, respectively:
[{'accuracy': 0.9737749695777893}]
[{'accuracy': 0.9733250141143799}]
[{'accuracy': 0.974358320236206}]
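Not great, but not bad either. Keep in mind that most comments carry none of the six labels, so per-label accuracy is boosted by the many easy negatives; per-label F1 or ROC AUC gives a clearer picture of how well the rare labels are detected.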