logitorch.pipelines.proof_qa_pipelines
======================================

.. py:module:: logitorch.pipelines.proof_qa_pipelines


Attributes
----------

.. autoapisummary::

   logitorch.pipelines.proof_qa_pipelines.FLD_COMPATIBLE_MODELS
   logitorch.pipelines.proof_qa_pipelines.PROOFWRITER_COMPATIBLE_MODELS


Functions
---------

.. autoapisummary::

   logitorch.pipelines.proof_qa_pipelines.fld_pipeline
   logitorch.pipelines.proof_qa_pipelines.proofwriter_pipeline


Module Contents
---------------

.. py:function:: fld_pipeline(model: torch.nn.Module, dataset_name: str, task: str = 'proof_generation_all', saved_model_path: str = '/', saved_model_name: str = 'best_model', batch_size: int = 4, accum_steps: int = 16, epochs: int = 40, accelerator: str = 'cpu', gpus: int = 0)

   Executes the FLD pipeline to train a proof generation model.

   Args:
       model (nn.Module): The proof generation model.
       dataset_name (str): The name of the dataset.
       task (str, optional): The task to perform. Defaults to "proof_generation_all".
       saved_model_path (str, optional): The path where the trained model is saved. Defaults to "/".
       saved_model_name (str, optional): The name of the saved model. Defaults to "best_model".
       batch_size (int, optional): The batch size for training. Defaults to 4.
       accum_steps (int, optional): The number of gradient accumulation steps. Defaults to 16.
       epochs (int, optional): The number of training epochs. Defaults to 40.
       accelerator (str, optional): The accelerator to use (e.g., "cpu" or "gpu"). Defaults to "cpu".
       gpus (int, optional): The number of GPUs to use. Defaults to 0.

   Raises:
       ModelNotCompatibleError: If the provided model is not compatible with the FLD pipeline.


.. py:function:: proofwriter_pipeline(model: torch.nn.Module, dataset_name: str, task: str = 'proof_generation_all', open_world_assumption: bool = False, saved_model_path: str = '/', saved_model_name: str = 'best_model', batch_size: int = 1, epochs: int = 1, accelerator: str = 'cpu', gpus: int = 0)

   Executes the ProofWriter pipeline to train a proof generation model.

   Args:
       model (nn.Module): The proof generation model.
       dataset_name (str): The name of the dataset.
       task (str, optional): The task to perform. Defaults to "proof_generation_all".
       open_world_assumption (bool, optional): Whether to use the open-world assumption. Defaults to False.
       saved_model_path (str, optional): The path where the trained model is saved. Defaults to "/".
       saved_model_name (str, optional): The name of the saved model. Defaults to "best_model".
       batch_size (int, optional): The batch size for training. Defaults to 1.
       epochs (int, optional): The number of training epochs. Defaults to 1.
       accelerator (str, optional): The accelerator to use (e.g., "cpu" or "gpu"). Defaults to "cpu".
       gpus (int, optional): The number of GPUs to use. Defaults to 0.

   Raises:
       ModelNotCompatibleError: If the provided model is not compatible with the ProofWriter pipeline.


.. py:data:: FLD_COMPATIBLE_MODELS

.. py:data:: PROOFWRITER_COMPATIBLE_MODELS
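

Assuming the ``accum_steps`` argument of :py:func:`fld_pipeline` works as a standard gradient accumulation factor (one optimizer step every ``accum_steps`` mini-batches), each update sees ``batch_size * accum_steps`` examples per device. With the defaults, that is 4 * 16 = 64. The sketch below only illustrates this arithmetic. ``effective_batch_size`` is a hypothetical helper that is not part of logitorch, and the multi-device case assumes data-parallel training in which each device processes its own mini-batch.

.. code-block:: python

   def effective_batch_size(batch_size: int, accum_steps: int = 1, num_devices: int = 1) -> int:
       """Hypothetical helper: number of examples behind one optimizer step.

       Assumes standard gradient accumulation (one optimizer step every
       ``accum_steps`` mini-batches) and data parallelism in which each of
       ``num_devices`` devices processes its own mini-batch of ``batch_size``.
       """
       if min(batch_size, accum_steps, num_devices) < 1:
           raise ValueError("batch_size, accum_steps and num_devices must all be >= 1")
       return batch_size * accum_steps * num_devices


   # fld_pipeline defaults (batch_size=4, accum_steps=16) on a single device
   print(effective_batch_size(4, 16))  # 64
   # Same settings spread over two devices (assumed data-parallel)
   print(effective_batch_size(4, 16, num_devices=2))  # 128

Keeping the per-step ``batch_size`` small while accumulating gradients limits memory use per step but still yields a larger effective batch for each update.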
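

Both pipelines raise ``ModelNotCompatibleError`` when the given model is not compatible with them. The module also exposes ``FLD_COMPATIBLE_MODELS`` and ``PROOFWRITER_COMPATIBLE_MODELS``, but this page does not document their contents. The following is a minimal sketch of such a check. It assumes that a compatibility list holds model classes and that membership is tested with ``isinstance``. Every name in it (``ModelNotCompatibleError`` as defined here, ``ProofGenModelA``, ``ProofGenModelB``, ``UnrelatedModel``, ``EXAMPLE_COMPATIBLE_MODELS`` and ``check_compatible``) is a hypothetical stand-in, not logitorch's actual definition.

.. code-block:: python

   from typing import Sequence, Type


   class ModelNotCompatibleError(Exception):
       """Hypothetical stand-in for logitorch's ModelNotCompatibleError."""


   # Hypothetical model classes standing in for real proof generation models.
   class ProofGenModelA:
       pass


   class ProofGenModelB:
       pass


   class UnrelatedModel:
       pass


   # Assumption: a compatibility list holds model classes (contents are illustrative).
   EXAMPLE_COMPATIBLE_MODELS: Sequence[Type] = (ProofGenModelA, ProofGenModelB)


   def check_compatible(model: object, compatible: Sequence[Type], pipeline: str) -> None:
       """Raise ModelNotCompatibleError unless ``model`` is an instance of an allowed class."""
       if not isinstance(model, tuple(compatible)):
           raise ModelNotCompatibleError(
               f"{type(model).__name__} is not compatible with the {pipeline} pipeline"
           )


   check_compatible(ProofGenModelA(), EXAMPLE_COMPATIBLE_MODELS, "fld")  # passes silently
   try:
       check_compatible(UnrelatedModel(), EXAMPLE_COMPATIBLE_MODELS, "fld")
   except ModelNotCompatibleError as err:
       print(err)  # UnrelatedModel is not compatible with the fld pipeline

Validating the model before any data loading or training starts means an incompatible model fails immediately instead of partway through a long run.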