from typing import Dict, Tuple
import lightning.pytorch as pl
import torch
from torch.optim import AdamW
from torch.utils.data.dataloader import DataLoader
from transformers import get_linear_schedule_with_warmup
from logitorch.data_collators.bertnot_collator import BERTNOTWiki20KCollator
from logitorch.datasets.mlm.wiki20k_dataset import Wiki20KDataset
from logitorch.models.bertnot import BERTNOT
class PLBERTNOT(pl.LightningModule):
    """
    PyTorch Lightning module for BERTNOT model.

    Args:
        pretrained_model (str): Pretrained model name or path.
        task (str): Task type, either "mlm" (masked language modeling) or "te" (text entailment).
            Any value other than "mlm" is handled by the "te" code path.
        num_labels (int): Number of labels for the classification task.
        learning_rate (float): Learning rate for the optimizer.
        weight_decay (float): Weight decay for the optimizer.
        batch_size (int): Batch size for the data loader.
        gamma (float): Weight of the unlikelihood loss on the negated data in the
            "mlm" objective; the KL loss on the positive data is weighted by ``1 - gamma``.

    Attributes:
        model (BERTNOT): BERTNOT model instance.
        pretrained_model (str): Pretrained model name or path.
        learning_rate (float): Learning rate for the optimizer.
        weight_decay (float): Weight decay for the optimizer.
        batch_size (int): Batch size for the data loader.
        gamma (float): Gamma value for the loss calculation.
        task (str): Task type, either "mlm" (masked language modeling) or "te" (text entailment).
    """

    def __init__(
        self,
        pretrained_model: str,
        task: str = "mlm",
        num_labels: int = 2,
        learning_rate: float = 1e-5,
        weight_decay: float = 0.1,
        batch_size: int = 32,
        gamma: float = 0.4,
    ) -> None:
        super().__init__()
        self.model = BERTNOT(pretrained_model, num_labels=num_labels)
        self.pretrained_model = pretrained_model
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        # These two must be stored before they are read: `task` selects the
        # code path in forward/training/validation and `gamma` weights the
        # "mlm" loss.
        self.gamma = gamma
        self.task = task

        # The "mlm" training step performs two optimizer updates per batch,
        # so the optimization loop is driven manually.
        if self.task == "mlm":
            self.automatic_optimization = False

    def forward(self, x, y=None, loss="cross_entropy"):
        """
        Forward pass of the PLBERTNOT model.

        Args:
            x: Input data.
            y: Target labels (optional).
            loss (str): Loss function type. Only forwarded to the model for the
                "mlm" task when ``y`` is given.

        Returns:
            The output of the wrapped BERTNOT model. When ``y`` is given, the
            training and validation steps of this class unpack it as ``(loss, _)``.
        """
        if self.task == "mlm":
            if y is not None:
                return self.model(x, y, task="mlm", loss=loss)
            return self.model(x, task="mlm")

        if y is not None:
            return self.model(x, y, task="te")
        return self.model(x, task="te")

    def train_dataloader(self):
        """
        Get the training data loaders.

        Returns:
            Dict[str, DataLoader]: Data loaders keyed by "negated_wiki20k",
            "positive_wiki20k" and "wiki20k". All of them share the same
            collator and batch size.
        """
        negated_wiki20k_dataset = Wiki20KDataset("negated_lm_wiki20k")
        positive_wiki20k_dataset = Wiki20KDataset("positive_lm_wiki20k")
        wiki20k_dataset = Wiki20KDataset("lm_wiki20k")

        collator_fn = BERTNOTWiki20KCollator(self.pretrained_model)

        loader_negated_wiki20k = DataLoader(
            negated_wiki20k_dataset, self.batch_size, collate_fn=collator_fn
        )
        loader_positive_wiki20k = DataLoader(
            positive_wiki20k_dataset, self.batch_size, collate_fn=collator_fn
        )
        loader_wiki20k = DataLoader(
            wiki20k_dataset, self.batch_size, collate_fn=collator_fn
        )

        return {
            "negated_wiki20k": loader_negated_wiki20k,
            "positive_wiki20k": loader_positive_wiki20k,
            "wiki20k": loader_wiki20k,
        }

    def training_step(self, train_batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int):
        """
        Training step of the PLBERTNOT model.

        For the "mlm" task the batch is indexed by the loader names
        ("negated_wiki20k", "positive_wiki20k", "wiki20k") and two manual
        optimizer updates are performed. Otherwise the batch is an ``(x, y)``
        pair and a single loss is computed.

        Args:
            train_batch: Batch of training data.
            batch_idx (int): Batch index.

        Returns:
            torch.Tensor: Loss value. For "mlm" this is the gamma-weighted
            combination of the unlikelihood and KL losses from the first update.
        """
        if self.task == "mlm":
            # NOTE(review): relies on an optimizer being configured for this
            # module; no configure_optimizers is visible in this section - confirm.
            optimizer = self.optimizers()

            # First update: unlikelihood on negated data mixed with
            # knowledge distillation (KL) on the positive data.
            x, y = train_batch["negated_wiki20k"]
            loss_ul_negated_wiki20k, _ = self(x, y, loss="unlikelihood")

            x, y = train_batch["positive_wiki20k"]
            loss_kl_positive_wiki20k, _ = self(x, y, loss="kl")

            loss = (
                self.gamma * loss_ul_negated_wiki20k
                + (1 - self.gamma) * loss_kl_positive_wiki20k
            )
            optimizer.zero_grad()
            self.manual_backward(loss)
            optimizer.step()

            # Second update: knowledge distillation (KL) on the plain wiki20k data.
            x, y = train_batch["wiki20k"]
            loss_kl_wiki20k, _ = self(x, y, loss="kl")
            optimizer.zero_grad()
            self.manual_backward(loss_kl_wiki20k)
            optimizer.step()

            self.log_dict(
                {
                    "unlikelihood loss negated wiki20k": loss_ul_negated_wiki20k,
                    "knowledge distillation loss positive wiki20k": loss_kl_positive_wiki20k,
                    "knowledge distillation loss wiki20k": loss_kl_wiki20k,
                },
                prog_bar=True,
            )
        else:
            x, y = train_batch
            loss, _ = self(x, y)
            self.log_dict({"train_loss": loss}, prog_bar=True, on_epoch=True)

        return loss

    def validation_step(self, val_batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int):
        """
        Validation step of the PLBERTNOT model.

        Logs ``val_loss``. For the "mlm" task it is the gamma-weighted
        unlikelihood/KL loss plus the KL loss on the "wiki20k" batch; otherwise
        it is the loss returned by the model for the ``(x, y)`` batch.

        Args:
            val_batch: Batch of validation data.
            batch_idx (int): Batch index.
        """
        if self.task == "mlm":
            x, y = val_batch["negated_wiki20k"]
            loss_ul_negated_wiki20k, _ = self(x, y, loss="unlikelihood")

            x, y = val_batch["positive_wiki20k"]
            loss_kl_positive_wiki20k, _ = self(x, y, loss="kl")

            loss = (
                self.gamma * loss_ul_negated_wiki20k
                + (1 - self.gamma) * loss_kl_positive_wiki20k
            )

            x, y = val_batch["wiki20k"]
            loss_kl_wiki20k, _ = self(x, y, loss="kl")
            loss += loss_kl_wiki20k

            self.log_dict({"val_loss": loss}, prog_bar=True)
        else:
            x, y = val_batch
            loss, _ = self(x, y)
            self.log_dict({"val_loss": loss}, prog_bar=True, on_epoch=True)

    def predict(self, context: str, hypothesis: str = None, task="mlm", device="cpu"):
        """
        Make predictions using the PLBERTNOT model.

        Delegates directly to ``self.model.predict``.

        Args:
            context (str): Input context.
            hypothesis (str): Input hypothesis (optional).
            task (str): Task type, either "mlm" (masked language modeling) or "te" (text entailment).
            device (str): Device to run the model on.

        Returns:
            The value returned by ``self.model.predict``.
        """
        return self.model.predict(context, hypothesis, task, device)