[docs]classRuleTakerCollator:""" A collator class for RuleTaker model. This collator is used to preprocess and collate data for RuleTaker model training or inference. """def__init__(self)->None:
def __call__(self, batch) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
    """
    Tokenize a batch of samples and stack their labels.

    Args:
        batch: An iterable of 4-element samples
            ``(context, question, label, extra)``. The fourth element is
            unpacked but never used.

    Returns:
        A pair ``(batch_x, batch_y)``. ``batch_x`` is whatever
        ``self.tokenizer`` returns when called with the list of contexts,
        the list of questions, ``padding=True`` and ``return_tensors="pt"``.
        ``batch_y`` is an ``int64`` tensor holding the labels in batch order.
    """
    # Materialise once so that ``batch`` is traversed a single time (it may
    # be a one-shot iterable); every sample must unpack into four fields.
    samples = [
        (context, question, label) for context, question, label, _ in batch
    ]
    contexts = [context for context, _, _ in samples]
    questions = [question for _, question, _ in samples]
    batch_y = [label for _, _, label in samples]

    batch_x = self.tokenizer(
        contexts, questions, padding=True, return_tensors="pt"
    )
    return batch_x, torch.tensor(batch_y, dtype=torch.int64)
[docs]classRuleTakerProofWriterCollator:""" A collator class for RuleTaker with ProofWriter model. This collator is used to preprocess and collate data for RuleTaker with ProofWriter model training or inference. """def__init__(self)->None:
def __call__(self, batch) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
    """
    Build textual contexts from each sample, tokenize them, and encode labels.

    Args:
        batch: An iterable of indexable samples. For each sample:
            index 0 and index 1 are mappings whose entries are rendered as
            ``"key: value"`` strings (index 0 first, then index 1);
            index 2 is the question; index 3 is the label, which is
            converted with ``str`` and looked up in
            ``PROOFWRITER_LABEL_TO_ID``.

    Returns:
        A pair ``(batch_x, labels)``. ``batch_x`` is whatever
        ``self.tokenizer`` returns when called with the list of contexts,
        the list of questions, ``padding=True`` and ``return_tensors="pt"``.
        ``labels`` is an ``int64`` tensor of the looked-up label ids in
        batch order.

    Raises:
        KeyError: If ``str(sample[3])`` is not a key of
            ``PROOFWRITER_LABEL_TO_ID``.
    """
    contexts = []
    questions = []
    labels = []

    for sample in batch:
        sentences = [f"{key}: {value}" for key, value in sample[0].items()]
        sentences += [f"{key}: {value}" for key, value in sample[1].items()]

        # NOTE(review): entries are concatenated with an empty separator, so
        # consecutive "key: value" pieces run together with no space between
        # them — confirm this is the intended model input format.
        contexts.append("".join(sentences))
        questions.append(sample[2])
        labels.append(PROOFWRITER_LABEL_TO_ID[str(sample[3])])

    batch_x = self.tokenizer(
        contexts, questions, padding=True, return_tensors="pt"
    )
    return batch_x, torch.tensor(labels, dtype=torch.int64)