class FLDAllAtOnceProver(nn.Module):
    """``nn.Module`` wrapper around a pretrained T5 model and its tokenizer.

    ``forward`` delegates to ``self.model(...)`` and ``predict`` uses
    ``self.tokenizer`` plus ``self.model.generate(...)`` to turn a text prompt
    into generated text.
    """

    def __init__(self, pretrained_t5_model: str) -> None:
        """
        Initializes the FLDAllAtOnceProver model.

        Args:
            pretrained_t5_model (str): The name or path of the pretrained T5 model.
        """
        super().__init__()
        # NOTE(review): as written, ``pretrained_t5_model`` is never used and
        # ``self.model`` / ``self.tokenizer`` (required by ``forward`` and
        # ``predict``) are never assigned in this constructor. Confirm they are
        # set elsewhere (subclass or caller), or load them here from
        # ``pretrained_t5_model``.

    def forward(self, x: Dict[str, torch.Tensor], y: torch.Tensor = None) -> SequenceClassifierOutput:
        """
        Performs a forward pass of the model.

        Args:
            x (Dict[str, torch.Tensor]): The input tensors, unpacked as keyword
                arguments into ``self.model``.
            y (torch.Tensor, optional): The labels tensor. When not None it is
                passed to ``self.model`` as ``labels=``. Defaults to None.

        Returns:
            SequenceClassifierOutput: The output of the model.
        """
        # Only pass ``labels`` when provided; presumably the model computes a
        # loss only when labels are supplied (confirm against the model class).
        if y is not None:
            return self.model(**x, labels=y)
        return self.model(**x)

    def predict(
        self,
        prompt: str,
        num_beams: int = 5,
        max_length: int = 1000,
        device: str = "cpu",
        *,
        do_sample: bool = True,
        top_p: float = 0.90,
    ) -> str:
        """
        Generates a prediction for the given prompt.

        Args:
            prompt (str): The input prompt.
            num_beams (int, optional): Number of beams passed to ``generate``.
                Defaults to 5.
            max_length (int, optional): The maximum length of the generated
                sequence. Defaults to 1000.
            device (str, optional): Device the tokenized inputs are moved to.
                The model itself is NOT moved here, so it must already be on
                this device. Defaults to "cpu".
            do_sample (bool, optional): Whether ``generate`` samples. Defaults
                to True, which matches the previously hard-coded behaviour;
                with sampling enabled, repeated calls may return different
                text. Pass False for deterministic beam search.
            top_p (float, optional): Nucleus-sampling probability mass, only
                forwarded when ``do_sample`` is True. Defaults to 0.90.

        Returns:
            str: The decoded text of the first generated sequence, with
            special tokens removed.
        """
        # top_p only matters when sampling; leave it out otherwise since
        # generate() would ignore it.
        sampling_kwargs = {"do_sample": do_sample}
        if do_sample:
            sampling_kwargs["top_p"] = top_p

        with torch.no_grad():
            tokenized_x = self.tokenizer(prompt, padding=True, return_tensors="pt")
            beam_output = self.model.generate(
                **tokenized_x.to(device),
                max_length=max_length,
                num_beams=num_beams,
                **sampling_kwargs,
            )
            # Decode only the first returned sequence.
            pred = self.tokenizer.decode(
                beam_output[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            return pred