logitorch.pl_models.ruletaker
=============================

.. py:module:: logitorch.pl_models.ruletaker


Classes
-------

.. autoapisummary::

   logitorch.pl_models.ruletaker.PLRuleTaker


Module Contents
---------------

.. py:class:: PLRuleTaker(learning_rate: float = 1e-05, weight_decay: float = 0.1, num_labels: int = 2)

   Bases: :py:obj:`lightning.pytorch.LightningModule`


   Initializes the PLRuleTaker module.

   Args:
       learning_rate (float): The learning rate for the optimizer. Default is 1e-5.
       weight_decay (float): The weight decay for the optimizer. Default is 0.1.
       num_labels (int): The number of labels for the RuleTaker model. Default is 2.


   .. py:method:: configure_optimizers()

      Configures the optimizer and scheduler for training.

      Returns:
          The optimizer and scheduler.


   .. py:method:: forward(x, y)

      Performs a forward pass of the PLRuleTaker module.

      Args:
          x: The input data.
          y: The target labels.

      Returns:
          The output of the model.


   .. py:method:: predict(context: str, question: str, device: str = 'cpu') -> int

      Predicts the label for a given context and question.

      Args:
          context (str): The context.
          question (str): The question.
          device (str): The device to use for prediction. Default is "cpu".

      Returns:
          The predicted label.


   .. py:method:: training_step(train_batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int) -> torch.Tensor

      Performs a training step.

      Args:
          train_batch: The batch of training data.
          batch_idx (int): The index of the batch.

      Returns:
          The training loss.


   .. py:method:: validation_step(val_batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int) -> None

      Performs a validation step.

      Args:
          val_batch: The batch of validation data.
          batch_idx (int): The index of the batch.


   .. py:attribute:: learning_rate
      :value: 1e-05


   .. py:attribute:: model


   .. py:attribute:: weight_decay
      :value: 0.1
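
Both ``training_step`` and ``validation_step`` take a batch typed as ``Tuple[Dict[str, torch.Tensor], torch.Tensor]``, that is, a dictionary of input tensors paired with a tensor of target labels. The minimal sketch below builds a dummy batch with that structure, which can be handy for smoke-testing a data pipeline. The helper ``make_dummy_batch`` is hypothetical, and the dictionary keys ``input_ids`` and ``attention_mask`` are an assumption based on typical Hugging Face tokenizer output; check the keys your own data collator actually produces before relying on them.

.. code-block:: python

   from typing import Dict, Tuple

   import torch


   # Hypothetical helper: builds a dummy batch matching the train_batch /
   # val_batch type hint, Tuple[Dict[str, torch.Tensor], torch.Tensor].
   # The dictionary keys are an assumption (typical tokenizer output).
   def make_dummy_batch(
       batch_size: int = 4, seq_len: int = 16, num_labels: int = 2
   ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
       inputs = {
           # Random token ids standing in for tokenized (context, question) pairs.
           "input_ids": torch.randint(0, 1000, (batch_size, seq_len)),
           # Every position is attended to; real batches mask padding with 0.
           "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
       }
       # Assumed: one integer class label per example, in [0, num_labels).
       labels = torch.randint(0, num_labels, (batch_size,))
       return inputs, labels


   x, y = make_dummy_batch()
   print(x["input_ids"].shape, y.shape)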