logitorch.pl_models.fld
=======================

.. py:module:: logitorch.pl_models.fld


Classes
-------

.. autoapisummary::

   logitorch.pl_models.fld.PLFLDAllAtOnceProver


Module Contents
---------------

.. py:class:: PLFLDAllAtOnceProver(pretrained_model: str = 't5-base', learning_rate: float = None, weight_decay=0.1, warmup_steps: Optional[int] = 1000)

   Bases: :py:obj:`lightning.pytorch.LightningModule`


   PyTorch Lightning module for the FLD (Formal Logic Deduction) all-at-once prover.

   Args:
       pretrained_model (str): Pretrained model name or path (default: "t5-base").
       learning_rate (float, optional): Learning rate for the optimizer (default: None).
       weight_decay (float): Weight decay for the optimizer (default: 0.1).
       warmup_steps (int, optional): Number of warmup steps for the learning rate scheduler (default: 1000).

   Attributes:
       model (FLDAllAtOnceProver): Underlying FLD prover model.
       pretrained_model (str): Pretrained model name or path.
       learning_rate (float): Learning rate for the optimizer.
       weight_decay (float): Weight decay for the optimizer.
       warmup_steps (int): Number of warmup steps for the learning rate scheduler.
       optimizer (AdamW): Optimizer used for training.


   .. py:method:: configure_optimizers()

      Configure the optimizer and the learning rate scheduler.

      Returns:
          Tuple[List[Optimizer], List[Dict[str, Any]]]: Optimizers and schedulers.


   .. py:method:: forward(x, y) -> transformers.modeling_outputs.SequenceClassifierOutput

      Forward pass of the model.

      Args:
          x: Input data.
          y: Target data.

      Returns:
          SequenceClassifierOutput: Model output.


   .. py:method:: predict(prompt: str, num_beams: int = 5, max_length: int = 1000, device: str = 'cpu')

      Generate predictions for a prompt using the model.

      Args:
          prompt (str): Input prompt.
          num_beams (int): Number of beams for beam search (default: 5).
          max_length (int): Maximum length of the generated sequence (default: 1000).
          device (str): Device to run prediction on (default: "cpu").

      Returns:
          Model predictions.


   .. py:method:: training_step(train_batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int) -> torch.Tensor

      Run a single training step.

      Args:
          train_batch: Batch of training data.
          batch_idx: Index of the batch.

      Returns:
          torch.Tensor: Loss value.


   .. py:method:: validation_step(val_batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int) -> None

      Run a single validation step.

      Args:
          val_batch: Batch of validation data.
          batch_idx: Index of the batch.


   .. py:attribute:: learning_rate
      :value: None


   .. py:attribute:: model


   .. py:attribute:: optimizer
      :value: None


   .. py:attribute:: pretrained_model
      :value: 't5-base'


   .. py:attribute:: warmup_steps
      :value: 1000


   .. py:attribute:: weight_decay
      :value: 0.1
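
The ``warmup_steps`` argument controls a learning rate warmup, but this page does not say which scheduler ``configure_optimizers`` builds. The sketch below assumes a linear warmup to the base learning rate followed by a linear decay to zero (the shape of Hugging Face's ``get_linear_schedule_with_warmup``). The helper ``lr_multiplier``, the ``total_steps`` value, and ``base_lr`` are hypothetical; note that the module's own ``learning_rate`` defaults to ``None``.

.. code-block:: python

   def lr_multiplier(step: int, warmup_steps: int = 1000, total_steps: int = 10_000) -> float:
       """Hypothetical linear-warmup / linear-decay factor (an assumed schedule)."""
       if step < warmup_steps:
           # Ramp up linearly from 0 to 1 over the warmup period.
           return step / max(1, warmup_steps)
       # Decay linearly from 1 to 0 between warmup_steps and total_steps.
       return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


   base_lr = 1e-4  # hypothetical base learning rate, for illustration only
   for step in (0, 500, 1000, 5500, 10_000):
       print(step, base_lr * lr_multiplier(step))

If the module uses a different post-warmup schedule (for example a constant learning rate), only the branch after the warmup period would change.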
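
The ``optimizer`` attribute is documented as ``AdamW``, and ``weight_decay`` defaults to ``0.1``. As a minimal sketch of what that setting does, the pure-Python function below performs one AdamW update on a single scalar parameter, following the update rule of ``torch.optim.AdamW``: decoupled weight decay shrinks the parameter directly, before the bias-corrected Adam step. The name ``adamw_step`` and every default other than ``weight_decay`` are illustrative assumptions; this is not code from logitorch.

.. code-block:: python

   import math


   def adamw_step(param, grad, m, v, t, lr=1e-3, weight_decay=0.1,
                  beta1=0.9, beta2=0.999, eps=1e-8):
       """One AdamW update for a scalar parameter (hypothetical helper)."""
       # Decoupled weight decay: shrink the parameter independently of the gradient.
       param = param * (1 - lr * weight_decay)
       # Exponential moving averages of the gradient and the squared gradient.
       m = beta1 * m + (1 - beta1) * grad
       v = beta2 * v + (1 - beta2) * grad * grad
       # Bias correction for the zero-initialised moments (t starts at 1).
       m_hat = m / (1 - beta1 ** t)
       v_hat = v / (1 - beta2 ** t)
       param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
       return param, m, v


   p, m, v = adamw_step(1.0, grad=0.0, m=0.0, v=0.0, t=1)
   print(p)

With a zero gradient the Adam term vanishes, so this call only shrinks the parameter by a factor of ``1 - lr * weight_decay``.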