Source code for logitorch.pipelines.proof_qa_pipelines

import lightning.pytorch as pl
import torch.nn as nn
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data.dataloader import DataLoader

from logitorch.data_collators.proofwriter_collator import (
    ProofWriterProofGenerationAllCollator,
)
from logitorch.data_collators.prover_collator import PRoverProofWriterCollator
from logitorch.data_collators.fld_collator import FLDProofGenerationAllCollator
from logitorch.datasets.proof_qa.proofwriter_dataset import ProofWriterDataset
from logitorch.datasets.proof_qa.fld_dataset import FLDDataset
from logitorch.pipelines.exceptions import ModelNotCompatibleError
from logitorch.pl_models.proofwriter import PLProofWriter
from logitorch.pl_models.prover import PLPRover
from logitorch.pl_models.fld import FLDAllAtOnceProver

[docs] PROOFWRITER_COMPATIBLE_MODELS = (PLProofWriter, PLPRover)
[docs] FLD_COMPATIBLE_MODELS = (FLDAllAtOnceProver,)
[docs] def proofwriter_pipeline( model: nn.Module, dataset_name: str, task: str = "proof_generation_all", open_world_assumption: bool = False, saved_model_path: str = "/", saved_model_name: str = "best_model", batch_size: int = 1, epochs: int = 1, accelerator: str = "cpu", gpus: int = 0, ): """ Executes the proofwriter pipeline for training a proof generation model. Args: model (nn.Module): The proof generation model. dataset_name (str): The name of the dataset. task (str, optional): The task to perform. Defaults to "proof_generation_all". open_world_assumption (bool, optional): Whether to use open world assumption. Defaults to False. saved_model_path (str, optional): The path to save the trained model. Defaults to "/". saved_model_name (str, optional): The name of the saved model. Defaults to "best_model". batch_size (int, optional): The batch size for training. Defaults to 1. epochs (int, optional): The number of training epochs. Defaults to 1. accelerator (str, optional): The accelerator to use (e.g., "cpu", "gpu"). Defaults to "cpu". gpus (int, optional): The number of GPUs to use. Defaults to 0. Raises: ModelNotCompatibleError: If the provided model is not compatible with the proofwriter pipeline. """ try: if isinstance(model, PROOFWRITER_COMPATIBLE_MODELS): train_dataset = ProofWriterDataset( dataset_name, "train", task, open_world_assumption ) val_dataset = ProofWriterDataset( dataset_name, "val", task, open_world_assumption ) if isinstance(model, PLProofWriter): proofwriter_collate_fn = ProofWriterProofGenerationAllCollator( model.pretrained_model ) elif isinstance(model, PLPRover): proofwriter_collate_fn = PRoverProofWriterCollator( model.pretrained_model ) train_dataloader = DataLoader( train_dataset, batch_size=batch_size, collate_fn=proofwriter_collate_fn ) val_dataloader = DataLoader( val_dataset, batch_size=batch_size, collate_fn=proofwriter_collate_fn ) checkpoint_callback = ModelCheckpoint( save_top_k=1, monitor="val_loss", mode="min", dirpath=saved_model_path, filename=saved_model_name, ) trainer = pl.Trainer( callbacks=[checkpoint_callback], max_epochs=epochs, accelerator=accelerator, gpus=gpus, ) trainer.fit(model, train_dataloader, val_dataloader) else: raise ModelNotCompatibleError(PROOFWRITER_COMPATIBLE_MODELS) except ModelNotCompatibleError as err: print(err.message)
[docs] def fld_pipeline( model: nn.Module, dataset_name: str, task: str = "proof_generation_all", saved_model_path: str = "/", saved_model_name: str = "best_model", batch_size: int = 4, accum_steps: int = 16, epochs: int = 40, accelerator: str = "cpu", gpus: int = 0, ): """ Executes the fld pipeline for training a proof generation model. Args: model (nn.Module): The proof generation model. dataset_name (str): The name of the dataset. task (str, optional): The task to perform. Defaults to "proof_generation_all". saved_model_path (str, optional): The path to save the trained model. Defaults to "/". saved_model_name (str, optional): The name of the saved model. Defaults to "best_model". batch_size (int, optional): The batch size for training. Defaults to 4. accum_steps (int, optional): The number of accumulation steps. Defaults to 16. epochs (int, optional): The number of training epochs. Defaults to 40. accelerator (str, optional): The accelerator to use (e.g., "cpu", "gpu"). Defaults to "cpu". gpus (int, optional): The number of GPUs to use. Defaults to 0. Raises: ModelNotCompatibleError: If the provided model is not compatible with the fld pipeline. """ try: if isinstance(model, FLD_COMPATIBLE_MODELS): train_dataset = FLDDataset( dataset_name, "train", task, ) val_dataset = FLDDataset( dataset_name, "val", task, max_samples=100, ) if isinstance(model, FLDAllAtOnceProver): fld_collate_fn = FLDProofGenerationAllCollator( "t5-base", log_examples=False, ) else: raise ValueError() train_dataloader = DataLoader( train_dataset, batch_size=batch_size, collate_fn=fld_collate_fn ) val_dataloader = DataLoader( val_dataset, batch_size=batch_size, collate_fn=fld_collate_fn ) checkpoint_callback = ModelCheckpoint( save_top_k=1, monitor="val_loss", mode="min", dirpath=saved_model_path, filename=saved_model_name, ) trainer = pl.Trainer( callbacks=[checkpoint_callback], auto_lr_find=False, accelerator=accelerator, accumulate_grad_batches=accum_steps, max_epochs=epochs, gpus=gpus, ) trainer.fit(model, train_dataloader, val_dataloader) else: raise ModelNotCompatibleError(FLD_COMPATIBLE_MODELS) except ModelNotCompatibleError as err: print(err.message)