class BERTNOT(nn.Module):
    """BERT-based module with a masked-LM path and a sequence-classification path.

    ``forward`` returns logits (and a loss when targets are supplied) for the
    requested task; ``predict`` tokenizes raw text, runs ``forward`` and
    returns a decoded token string or a label index.
    """

    def __init__(self, pretrained_bert_model: str, num_labels: int = 2) -> None:
        """
        BERTNOT model for fine-tuning BERT for various tasks.

        Args:
            pretrained_bert_model (str): Path or identifier of the pre-trained BERT model.
            num_labels (int, optional): Number of labels for the classification task.
                Defaults to 2.
        """
        super().__init__()
        # NOTE(review): forward()/predict() read self.tasks, self.losses,
        # self.model, self.original_bert, self.original_bert_softmax,
        # self.log_softmax, self.kl_loss, self.cross_entopy_loss,
        # self.unlikelihood_loss, self.dropout, self.sequence_classifier,
        # self.num_labels and self.tokenizer, none of which are assigned in the
        # visible constructor body. Presumably the assignments were lost when
        # this listing was extracted -- confirm against the real source.

    def forward(self, x, y=None, task="mlm", loss="cross_entropy"):
        """
        Forward pass of the BERTNOT model.

        Args:
            x (dict): Mapping of input tensors; it is unpacked as keyword
                arguments into the underlying model(s).
            y (torch.Tensor, optional): Target tensor. Defaults to None, in which
                case no loss is computed.
            task (str, optional): Task type; must be one of ``self.tasks``.
                ``"mlm"`` selects the masked-LM path, any other accepted value
                selects the classification path. Defaults to "mlm".
            loss (str, optional): Loss type; must be one of ``self.losses``. It is
                only consulted on the ``"mlm"`` path when ``y`` is given:
                ``"cross_entropy"``, ``"unlikelihood"``, or (any other accepted
                value) the loss computed by ``self.kl_loss`` against
                ``self.original_bert``. Defaults to "cross_entropy".

        Returns:
            tuple: ``(loss, logits)`` if ``y`` is not None, otherwise ``logits``.
            ``None`` is returned (after printing the error message) when ``task``
            is not in ``self.tasks``.

        Raises:
            LossError: If ``loss`` is not in ``self.losses``. (If ``LossError``
                derives from ``TaskError`` it is instead printed and ``None`` is
                returned -- the class hierarchy is not visible here.)
        """
        # Only the argument validation can raise these errors, so the try block
        # is limited to it. LossError is raised inside the same try so that it
        # is handled exactly as before whatever its base class is.
        try:
            if task not in self.tasks:
                raise TaskError(self.tasks)
            if loss not in self.losses:
                raise LossError(self.losses)
        except TaskError as err:
            print(err.message)
            return None

        if task == "mlm":
            logits = self.model(**x).logits
            if y is None:
                return logits

            vocab_size = self.model.config.vocab_size
            if loss == "cross_entropy":
                loss_value = self.cross_entopy_loss(
                    logits.view(-1, vocab_size), y.view(-1)
                )
            elif loss == "unlikelihood":
                loss_value = self.unlikelihood_loss(
                    logits.view(-1, vocab_size), y.view(-1)
                )
            else:
                # Compare this model's output with the reference model's output,
                # restricted to the positions whose target is not -100.
                reference_outputs = self.original_bert(**x)[0]
                target_positions = torch.ne(y, -100)
                reference_probs = self.original_bert_softmax(
                    reference_outputs[target_positions]
                )
                pred_probs = self.log_softmax(logits[target_positions])
                loss_value = self.kl_loss(pred_probs, reference_probs)
            return (loss_value, logits)

        # Classification path: classify from the first position of every sequence.
        encoder_outputs = self.model.bert(**x)[0]
        cls_representation = encoder_outputs[:, 0, :]
        logits = self.sequence_classifier(self.dropout(cls_representation))
        if y is None:
            return logits
        loss_value = self.cross_entopy_loss(
            logits.view(-1, self.num_labels), y.view(-1)
        )
        return (loss_value, logits)

    def predict(self, context: str, hypothesis: str = None, task="mlm", device="cpu"):
        """
        Perform prediction using the BERTNOT model.

        The forward pass runs under ``torch.no_grad()`` because only the argmax
        of the logits is used. The module's train/eval mode is left untouched;
        callers who want dropout disabled should call ``.eval()`` first.

        Args:
            context (str): Input context string.
            hypothesis (str, optional): Input hypothesis string; when given it is
                passed to the tokenizer as the second text. Defaults to None.
            task (str, optional): Task type; must be one of ``self.tasks``.
                Defaults to "mlm".
            device (str, optional): Device the tokenized inputs are moved to.
                Defaults to "cpu".

        Returns:
            str or int: For ``"mlm"``, the decoded prediction at the mask-token
            position(s) of the first sequence; otherwise the index of the largest
            logit. ``None`` is returned (after printing the error message) when
            ``task`` is not in ``self.tasks``.
        """
        try:
            if task not in self.tasks:
                raise TaskError(self.tasks)
        except TaskError as err:
            print(err.message)
            return None

        if hypothesis is None:
            tokenized_x = self.tokenizer(context, return_tensors="pt")
        else:
            tokenized_x = self.tokenizer(context, hypothesis, return_tensors="pt")

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            logits = self(tokenized_x.to(device), task=task)

        if task == "mlm":
            is_mask = tokenized_x.input_ids == self.tokenizer.mask_token_id
            mask_token_indexes = is_mask[0].nonzero(as_tuple=True)[0]
            predicted_token_id = logits[0, mask_token_indexes].argmax(axis=-1)
            return self.tokenizer.decode(predicted_token_id)

        return logits.argmax().item()