Source code for logitorch.models.bertnot

import torch
import torch.nn as nn
from torch.nn.init import xavier_normal_
from transformers import BertForMaskedLM, BertTokenizer

from logitorch.losses.unlikelihood_loss import UnlikelihoodLoss
from logitorch.models.exceptions import LossError, TaskError


class BERTNOT(nn.Module):
    """BERT wrapper with a masked-LM path and a sequence-classification path.

    Two tasks are supported (see ``self.tasks``):

    * ``"mlm"`` -- masked language modelling through ``BertForMaskedLM``,
      trainable with cross-entropy, unlikelihood, or a KL-divergence loss
      computed against a second copy of the same pretrained checkpoint
      (``self.original_bert``).
    * ``"te"`` -- sequence classification from the first-token hidden state
      through a dropout + linear head.
      (presumably "textual entailment" -- TODO confirm against the docs)
    """

    def __init__(self, pretrained_bert_model: str, num_labels: int = 2) -> None:
        """Load the pretrained checkpoints and build the classification head.

        Args:
            pretrained_bert_model (str): Path or identifier of the
                pre-trained BERT model.
            num_labels (int, optional): Number of labels for the
                classification task. Defaults to 2.
        """
        super().__init__()
        # Attribute assignment order is kept as in the original so that
        # sub-module registration order (and thus parameter order) is stable.
        self.model = BertForMaskedLM.from_pretrained(pretrained_bert_model)
        config = self.model.config
        # Fall back to the generic hidden dropout when the checkpoint does not
        # configure a classifier-specific dropout.
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.sequence_classifier = nn.Linear(config.hidden_size, num_labels)
        # Second copy of the same checkpoint, used as the target distribution
        # for the "kl" loss.
        self.original_bert = BertForMaskedLM.from_pretrained(pretrained_bert_model)
        self.tasks = ["mlm", "te"]
        self.losses = ["cross_entropy", "unlikelihood", "kl"]
        self.num_labels = num_labels
        self.original_bert_softmax = nn.Softmax(dim=1)
        self.log_softmax = nn.LogSoftmax(dim=1)
        # NOTE: the attribute name (typo included) is kept because external
        # code may already reference it.
        self.cross_entopy_loss = nn.CrossEntropyLoss()
        self.unlikelihood_loss = UnlikelihoodLoss()
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_model)
        xavier_normal_(self.sequence_classifier.weight)

    def forward(self, x, y=None, task="mlm", loss="cross_entropy"):
        """Run the model for ``task`` and optionally compute a loss.

        Args:
            x (dict): Keyword arguments forwarded to the underlying BERT
                model via ``**x``.
            y (torch.Tensor, optional): Target tensor. Defaults to None.
            task (str, optional): One of ``self.tasks``. Defaults to "mlm".
            loss (str, optional): One of ``self.losses``; only consulted for
                the "mlm" task. Defaults to "cross_entropy".

        Returns:
            ``(loss, logits)`` when ``y`` is given, otherwise ``logits``.
            ``None`` when ``task`` is invalid (the error message is printed).

        Raises:
            LossError: If ``loss`` is not in ``self.losses``.
                NOTE(review): unlike ``TaskError`` this is not caught here and
                propagates to the caller -- behaviour kept from the original.
        """
        try:
            if task not in self.tasks:
                raise TaskError(self.tasks)
            if loss not in self.losses:
                raise LossError(self.losses)
            if task == "mlm":
                return self._forward_mlm(x, y, loss)
            return self._forward_te(x, y)
        except TaskError as err:
            print(err.message)
            return None

    def _forward_mlm(self, x, y, loss):
        """Masked-LM forward pass; returns logits or ``(loss, logits)``."""
        logits = self.model(**x).logits
        if y is None:
            return logits

        if loss == "cross_entropy":
            loss_value = self.cross_entopy_loss(
                logits.view(-1, self.model.config.vocab_size), y.view(-1)
            )
        elif loss == "unlikelihood":
            loss_value = self.unlikelihood_loss(
                logits.view(-1, self.model.config.vocab_size), y.view(-1)
            )
        else:
            # The reference model only supplies the target distribution, so it
            # is evaluated without autograd: no graph is built for it and no
            # gradient can reach its weights through the KL term.
            with torch.no_grad():
                reference_logits = self.original_bert(**x)[0]
            # Keep only the positions whose label is not -100 (the default
            # ``ignore_index`` of ``nn.CrossEntropyLoss``).
            supervised = torch.ne(y, -100)
            target_probs = self.original_bert_softmax(reference_logits[supervised])
            # ``nn.KLDivLoss`` expects log-probabilities as its input.
            log_probs = self.log_softmax(logits[supervised])
            loss_value = self.kl_loss(log_probs, target_probs)
        return (loss_value, logits)

    def _forward_te(self, x, y):
        """Classification forward pass; returns logits or ``(loss, logits)``."""
        hidden_states = self.model.bert(**x)[0]
        # The first position of every sequence is used as the sequence summary.
        first_token_state = hidden_states[:, 0, :]
        logits = self.sequence_classifier(self.dropout(first_token_state))
        if y is None:
            return logits
        loss_value = self.cross_entopy_loss(
            logits.view(-1, self.num_labels), y.view(-1)
        )
        return (loss_value, logits)

    def predict(self, context: str, hypothesis: str = None, task="mlm", device="cpu"):
        """Tokenize the input(s) and return the model's prediction.

        Args:
            context (str): Input context string.
            hypothesis (str, optional): Second segment, tokenized together
                with ``context`` when given. Defaults to None.
            task (str, optional): One of ``self.tasks``. Defaults to "mlm".
            device (str, optional): Device the tokenized inputs are moved to.
                The model itself is not moved by this method.
                Defaults to "cpu".

        Returns:
            str or int: For "mlm", the decoded token(s) predicted at the mask
            position(s); otherwise the index of the largest logit.
            ``None`` when ``task`` is invalid (the error message is printed).
        """
        try:
            if task not in self.tasks:
                raise TaskError(self.tasks)

            if hypothesis is None:
                tokenized_x = self.tokenizer(context, return_tensors="pt")
            else:
                tokenized_x = self.tokenizer(context, hypothesis, return_tensors="pt")

            # Pure inference: skip autograd bookkeeping.
            with torch.no_grad():
                logits = self(tokenized_x.to(device), task=task)

            if task == "mlm":
                is_mask = tokenized_x.input_ids == self.tokenizer.mask_token_id
                mask_positions = is_mask[0].nonzero(as_tuple=True)[0]
                predicted_token_id = logits[0, mask_positions].argmax(dim=-1)
                return self.tokenizer.decode(predicted_token_id)
            return logits.argmax().item()
        except TaskError as err:
            print(err.message)
            return None