from typing import Dict, Optional, Tuple

import lightning.pytorch as pl
import torch
from transformers import Adafactor, get_linear_schedule_with_warmup
from transformers.modeling_outputs import SequenceClassifierOutput

from logitorch.models.proofwriter import ProofWriter
[docs]
class PLProofWriter(pl.LightningModule):
def __init__(
    self,
    pretrained_model: str = "google/t5-v1_1-large",
    learning_rate: Optional[float] = None,
    weight_decay: float = 0.1,
) -> None:
    """
    Initializes a PLProofWriter object.

    Args:
        pretrained_model (str, optional): The name or path of the pretrained model to use. Defaults to "google/t5-v1_1-large".
        learning_rate (Optional[float], optional): The learning rate for the optimizer. Defaults to None; the value is stored exactly as given.
        weight_decay (float, optional): The weight decay for the optimizer. Defaults to 0.1.
    """
    super().__init__()
    # The wrapped ProofWriter is built from the same identifier that is
    # kept on the instance below.
    self.model = ProofWriter(pretrained_model)
    # Constructor arguments are stored unchanged; nothing is validated or
    # converted here.
    self.pretrained_model = pretrained_model
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
[docs]
def forward(self, x, y) -> SequenceClassifierOutput:  # type: ignore
    """
    Runs the wrapped model on one pair of inputs and targets.

    Args:
        x: The input data, handed to the wrapped model unchanged.
        y: The target data, handed to the wrapped model unchanged.

    Returns:
        SequenceClassifierOutput: Whatever the wrapped model returns for ``(x, y)``.
    """
    # Pure delegation: this wrapper adds no processing of its own.
    outputs = self.model(x, y)
    return outputs
[docs]
def predict(
    self,
    context: str,
    question: str,
    num_beams: int = 5,
    max_length: int = 120,
    device: str = "cpu",
):
    """
    Delegates prediction to the wrapped model's ``predict`` method.

    Args:
        context (str): The context for the prediction.
        question (str): The question for the prediction.
        num_beams (int, optional): The number of beams for beam search decoding. Defaults to 5.
        max_length (int, optional): The maximum length of the generated sequence. Defaults to 120.
        device (str, optional): The device to use for prediction. Defaults to "cpu".

    Returns:
        The value returned by the wrapped model's ``predict`` call.
    """
    # Arguments are forwarded positionally, in exactly this order.
    forwarded_args = (context, question, num_beams, max_length, device)
    return self.model.predict(*forwarded_args)
[docs]
def training_step(self, train_batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int) -> torch.Tensor:  # type: ignore
    """
    Computes and logs the loss for one training batch.

    Args:
        train_batch (Tuple[Dict[str, torch.Tensor], torch.Tensor]): A pair of (inputs, targets).
        batch_idx (int): The index of the batch; not used by this method.

    Returns:
        torch.Tensor: The ``loss`` attribute of the output produced for this batch.
    """
    inputs, targets = train_batch
    # Calling the module itself routes through ``forward``.
    outputs = self(inputs, targets)
    train_loss = outputs.loss
    self.log("train_loss", train_loss, on_epoch=True)
    return train_loss
[docs]
def validation_step(self, val_batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int) -> None:  # type: ignore
    """
    Computes and logs the loss for one validation batch.

    Nothing is returned; the loss is only recorded under ``"val_loss"``.

    Args:
        val_batch (Tuple[Dict[str, torch.Tensor], torch.Tensor]): A pair of (inputs, targets).
        batch_idx (int): The index of the batch; not used by this method.
    """
    inputs, targets = val_batch
    # Calling the module itself routes through ``forward``.
    outputs = self(inputs, targets)
    val_loss = outputs.loss
    self.log("val_loss", val_loss, on_epoch=True)