Source code for logitorch.data_collators.fairr_collator
from transformers import RobertaTokenizer
class RuleSelectionProofWriterIterCollator:
    """Collate question/facts/rules samples into rule-selection inputs.

    Each sample becomes one text sequence built from the question, the facts
    and the rules, paired with an integer label: the 1-based position of the
    rule whose key occurs in the first proof, or 0 when no rule matches.
    """

    def __init__(self, pretrained_roberta_tokenizer: str) -> None:
        """
        Initializes the RuleSelectionProofWriterIterCollator.

        Args:
            pretrained_roberta_tokenizer (str): The name or path of the
                pretrained Roberta tokenizer.
        """
        # __call__ reads self.tokenizer, so it has to be created here.
        self.tokenizer = RobertaTokenizer.from_pretrained(
            pretrained_roberta_tokenizer
        )

    def __call__(self, batch):
        """
        Processes a batch of data.

        Args:
            batch: Iterable of 5-item samples
                ``(facts, rules, question, _, proofs)``. ``facts`` and
                ``rules`` are mappings whose values are sentences; the fourth
                item is ignored; only ``proofs[0]`` is inspected.

        Returns:
            tuple: ``(batch_x, batch_y)`` where ``batch_x`` is the tokenizer
            output for the assembled texts (called with ``padding=True`` and
            ``return_tensors="pt"``) and ``batch_y`` is a list with one
            integer label per sample.
        """
        texts = []
        batch_y = []

        for facts_list, rules_list, question, _, proofs in batch:
            facts = list(facts_list.values())
            rules = list(rules_list.values())

            # Only the first proof is used. A missing or empty proofs
            # container is treated like an explicit None proof (label 0).
            first_proof = proofs[0] if proofs else None

            y = 0
            if first_proof is not None:
                # Labels are 1-based so that 0 can mean "no rule selected".
                # Membership is tested on the rule KEY, and the last matching
                # rule wins.
                for position, rule_key in enumerate(rules_list, start=1):
                    if rule_key in first_proof:
                        y = position

            # NOTE(review): the last fact and the first rule are concatenated
            # with no separator between them. The format is kept exactly
            # as-is because changing it would alter the model inputs --
            # confirm this is intended.
            texts.append(
                question + " </s> " + " ".join(facts) + " </s> ".join(rules)
            )
            batch_y.append(y)

        batch_x = self.tokenizer(texts, padding=True, return_tensors="pt")
        return batch_x, batch_y