# Source code for logitorch.datasets.base

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union, Optional

from torch.utils.data import Dataset


class BaseLogicDataset(Dataset, ABC):
    """Abstract root of the dataset classes defined in this module.

    Combines ``torch.utils.data.Dataset`` with ``ABC``: a subclass can only
    be instantiated once it overrides every abstract method, which at this
    level means ``__len__`` and ``__str__``.
    """

    def __init__(self) -> None:
        """Delegate to the parent initialisers; no state is set here."""
        super().__init__()

    @abstractmethod
    def __len__(self) -> int:
        """Return the length of the dataset as an ``int``.

        Raises:
            NotImplementedError: Always, if the base implementation is
                called directly; subclasses must override this method.
        """
        raise NotImplementedError()

    @abstractmethod
    def __str__(self) -> str:
        """Return a string representation of the dataset.

        Raises:
            NotImplementedError: Always, if the base implementation is
                called directly; subclasses must override this method.
        """
        raise NotImplementedError()
class AbstractMCQADataset(BaseLogicDataset):
    """Abstract dataset whose items are 3- or 4-element tuples.

    NOTE(review): judging by the name, this is presumably the base for
    multiple-choice question answering datasets -- confirm against the
    concrete subclasses.
    """

    @abstractmethod
    def __getitem__(
        self, index: int
    ) -> Union[Tuple[str, str, List[str], int], Tuple[str, str, List[str]]]:
        """Return the item stored at ``index``.

        Args:
            index: Integer position of the requested item.

        Returns:
            Either ``(str, str, List[str], int)`` or ``(str, str, List[str])``;
            the meaning of each field is defined by the concrete subclass.

        Raises:
            NotImplementedError: Always, if the base implementation is
                called directly; subclasses must override this method.
        """
        raise NotImplementedError()
class AbstractTEDataset(BaseLogicDataset):
    """Abstract dataset whose items are ``(str, str, int)`` tuples.

    NOTE(review): "TE" presumably stands for textual entailment -- confirm
    against the concrete subclasses.
    """

    @abstractmethod
    def __getitem__(self, index: int) -> Tuple[str, str, int]:
        """Return the item stored at ``index``.

        Args:
            index: Integer position of the requested item.

        Returns:
            A ``(str, str, int)`` tuple; the meaning of each field is
            defined by the concrete subclass.

        Raises:
            NotImplementedError: Always, if the base implementation is
                called directly; subclasses must override this method.
        """
        raise NotImplementedError()
class AbstractQADataset(BaseLogicDataset):
    """Abstract dataset whose items take one of four tuple shapes.

    NOTE(review): presumably the base for question answering datasets,
    going by the name -- confirm against the concrete subclasses.
    """

    @abstractmethod
    def __getitem__(
        self, index: int
    ) -> Union[
        Tuple[str, str, int],
        Tuple[str, str, str],
        Tuple[str, str, int, int],
        Tuple[List[str], str, int, List[str]],
    ]:
        """Return the item stored at ``index``.

        Args:
            index: Integer position of the requested item.

        Returns:
            One of the tuple shapes listed in the return annotation; which
            shape applies, and what each field means, is decided by the
            concrete subclass.

        Raises:
            NotImplementedError: Always, if the base implementation is
                called directly; subclasses must override this method.
        """
        raise NotImplementedError()
class AbstractProofQADataset(BaseLogicDataset):
    """Abstract dataset whose items are tuples of dicts and lists, or a dict.

    NOTE(review): presumably the base for question answering datasets that
    also carry proofs, going by the name -- confirm against the concrete
    subclasses.
    """

    @abstractmethod
    def __getitem__(
        self, index: int
    ) -> Union[
        Tuple[
            Dict[str, str], Dict[str, str], List[str], List[str], List[str], List[str]
        ],
        Tuple[Dict[str, str], Dict[str, str], List[str], List[str], List[str]],
        Tuple[Dict[str, str], Dict[str, str], List[str], List[str]],
        Tuple[Dict[str, str], Dict[str, str], List[str]],
        Dict[str, Union[Optional[str], Optional[int]]],
    ]:
        """Return the item stored at ``index``.

        Args:
            index: Integer position of the requested item.

        Returns:
            Either a tuple that starts with two ``Dict[str, str]`` values
            followed by one to four ``List[str]`` values, or a single dict
            mapping ``str`` keys to ``str``, ``int`` or ``None``. Which
            shape applies, and what each field means, is decided by the
            concrete subclass.

        Raises:
            NotImplementedError: Always, if the base implementation is
                called directly; subclasses must override this method.
        """
        raise NotImplementedError()