from typing import Dict, Tuple, Optional
import lightning.pytorch as pl
import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from transformers.modeling_outputs import SequenceClassifierOutput
from logitorch.models.fld import FLDAllAtOnceProver
[docs]
class PLFLDAllAtOnceProver(pl.LightningModule):
"""
PyTorch Lightning module wrapping the Formal Logic Deduction (FLD) all-at-once prover.

Args:
    pretrained_model (str): Pretrained model name or path (default: "t5-base").
    learning_rate (float, optional): Learning rate for the optimizer (default: None).
    weight_decay (float): Weight decay for the optimizer (default: 0.1).
    warmup_steps (int, optional): Number of warmup steps for the learning rate scheduler (default: 1000).

Attributes:
    model (FLDAllAtOnceProver): FLD model.
    pretrained_model (str): Pretrained model name or path.
    learning_rate (float): Learning rate for the optimizer.
    weight_decay (float): Weight decay for the optimizer.
    warmup_steps (int): Number of warmup steps for the learning rate scheduler.
    optimizer (AdamW): Optimizer for training. Not assigned in ``__init__``; it must be
        set before ``training_step`` reads it.
"""
def __init__(
    self,
    pretrained_model: str = "t5-base",
    learning_rate: Optional[float] = None,
    weight_decay: float = 0.1,
    warmup_steps: Optional[int] = 1000,
) -> None:
    """
    Build the wrapped FLD prover and store the training hyper-parameters.

    Args:
        pretrained_model (str): Name or path handed to ``FLDAllAtOnceProver``
            (default: "t5-base").
        learning_rate (float, optional): Stored as ``self.learning_rate``
            (default: None). No validation is done here.
        weight_decay (float): Stored as ``self.weight_decay`` (default: 0.1).
        warmup_steps (int, optional): Stored as ``self.warmup_steps``
            (default: 1000).
    """
    super().__init__()

    # The only object constructed here; everything below is plain bookkeeping.
    self.model = FLDAllAtOnceProver(pretrained_model)

    # Hyper-parameters are kept verbatim on the instance; this method does not
    # create an optimizer or scheduler itself.
    self.pretrained_model = pretrained_model
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
    self.warmup_steps = warmup_steps
[docs]
def forward(self, x, y) -> SequenceClassifierOutput:
    """
    Delegate a forward pass to the wrapped model.

    Args:
        x: Input data, passed through unchanged as the first argument.
        y: Target data, passed through unchanged as the second argument.

    Returns:
        SequenceClassifierOutput: Whatever ``self.model(x, y)`` returns.
    """
    model_output = self.model(x, y)
    return model_output
[docs]
def predict(
    self,
    prompt: str,
    num_beams: int = 5,
    max_length: int = 1000,
    device: str = "cpu",
):
    """
    Run generation by delegating to the wrapped model's ``predict`` method.

    Args:
        prompt (str): Input prompt.
        num_beams (int): Number of beams for beam search (default: 5).
        max_length (int): Maximum length of generated sequence (default: 1000).
        device (str): Device to use for prediction (default: "cpu").

    Returns:
        The value returned by ``self.model.predict``.
    """
    # Arguments are forwarded positionally, in this exact order.
    generate = self.model.predict
    return generate(prompt, num_beams, max_length, device)
[docs]
def training_step(
    self,
    train_batch: Tuple[Dict[str, torch.Tensor], torch.Tensor],
    batch_idx: int,
) -> torch.Tensor:
    """
    Compute and log the loss for one training batch.

    Args:
        train_batch: Pair ``(x, y)`` that is unpacked and fed to ``self(x, y)``.
        batch_idx: Index of the batch (not used in the body).

    Returns:
        torch.Tensor: The ``loss`` attribute of the model output.
    """
    inputs, targets = train_batch
    outputs = self(inputs, targets)
    step_loss = outputs.loss
    self.log("train_loss", step_loss, on_step=True)

    # Echo the learning rate of every parameter group to stdout on each step.
    # ``self.optimizer`` must already exist on the instance when this runs.
    for group in self.optimizer.param_groups:
        rate = group["lr"]
        print("Current learning rate is: {}".format(rate))

    return step_loss
[docs]
def validation_step(
    self,
    val_batch: Tuple[Dict[str, torch.Tensor], torch.Tensor],
    batch_idx: int,
) -> None:
    """
    Compute and log the loss for one validation batch.

    Args:
        val_batch: Pair ``(x, y)`` that is unpacked and fed to ``self(x, y)``.
        batch_idx: Index of the batch (not used in the body).
    """
    inputs, targets = val_batch
    outputs = self(inputs, targets)
    # Nothing is returned; the loss is only reported under "val_loss".
    self.log("val_loss", outputs.loss, on_epoch=True)