Source code for logitorch.datasets.qa.ruletaker_dataset

import os
from typing import List, Tuple

from logitorch.datasets.base import AbstractQADataset
from logitorch.datasets.exceptions import DatasetNameError, SplitSetError
from logitorch.datasets.utils import (
    DATASETS_FOLDER,
    SPLIT_SETS,
    download_dataset,
    read_jsonl,
)

# Download link for the zipped RuleTaker data (handed to ``download_dataset``).
RULETAKER_DATASET_ZIP_URL = (
    "https://www.dropbox.com/s/1ix9j86nlgl0ygp/ruletaker_dataset.zip?dl=1"
)

# Dataset name: used as the folder name under DATASETS_FOLDER.
RULETAKER_DATASET = "ruletaker_dataset"

# Sub-dataset names accepted by RuleTakerDataset; each is a sub-folder of
# RULETAKER_DATASET_FOLDER.
RULETAKER_SUB_DATASETS = [
    "birds-electricity",
    "depth-0",
    "depth-1",
    "depth-2",
    "depth-3",
    "depth-3ext",
    "depth-3ext-NatLang",
    "depth-5",
    "NatLang",
]

# Local folder that holds every RuleTaker sub-dataset.
RULETAKER_DATASET_FOLDER = f"{DATASETS_FOLDER}/{RULETAKER_DATASET}"

# Mapping between the textual labels and their integer ids, and its inverse.
RULETAKER_LABEL_TO_ID = {"False": 0, "True": 1}
RULETAKER_ID_TO_LABEL = {
    label_id: label for label, label_id in RULETAKER_LABEL_TO_ID.items()
}
class RuleTakerDataset(AbstractQADataset):
    """Question-answering dataset backed by one split of a RuleTaker sub-dataset.

    Every item is a ``(context, question, label, depth)`` tuple. The label is
    the integer id looked up in ``RULETAKER_LABEL_TO_ID`` and the depth is the
    value stored under ``meta["QDep"]`` of the question.
    """

    def __init__(self, dataset_name: str, split_set: str) -> None:
        """Validate the arguments, fetch the data if needed and load the split.

        :param dataset_name: one of the names in ``RULETAKER_SUB_DATASETS``.
        :param split_set: one of ``SPLIT_SETS``; ``"val"`` is read from the
            ``dev.jsonl`` file. For ``"birds-electricity"`` only ``"test"``
            is accepted.

        NOTE: an invalid ``dataset_name`` or ``split_set`` does not raise out
        of the constructor. The error message is printed and the instance is
        returned without ``dataset_name``, ``split_set``, ``dataset_path`` or
        any of the data lists being set.
        """
        super().__init__()

        try:
            if dataset_name not in RULETAKER_SUB_DATASETS:
                raise DatasetNameError()
            if dataset_name == "birds-electricity" and split_set != "test":
                raise SplitSetError(["test"])
            if split_set != "val" and split_set not in SPLIT_SETS:
                raise SplitSetError(SPLIT_SETS)
            if split_set == "val":
                # The validation split is stored on disk as "dev".
                split_set = "dev"

            # Download only when the dataset folder is not present yet.
            if not os.path.exists(RULETAKER_DATASET_FOLDER):
                download_dataset(RULETAKER_DATASET_ZIP_URL, RULETAKER_DATASET)

            self.dataset_name = dataset_name
            self.split_set = split_set
            self.dataset_path = (
                f"{RULETAKER_DATASET_FOLDER}/{self.dataset_name}/{self.split_set}.jsonl"
            )

            columns = self.__read_dataset("context", "questions", "text", "label")
            self.contexts, self.questions, self.labels, self.depths = columns
        except DatasetNameError as err:
            print(err.message)
            print(f"The RuleTaker datasets are: {RULETAKER_SUB_DATASETS}")
        except SplitSetError as err:
            print(err.message)

    def __read_dataset(
        self,
        contexts_key: str,
        questions_key: str,
        questions_text_key: str,
        labels_key: str,
    ) -> Tuple[List[str], List[str], List[int], List[int]]:
        """Read ``self.dataset_path`` and flatten it to one entry per question.

        :param contexts_key: record key holding the context.
        :param questions_key: record key holding the iterable of questions.
        :param questions_text_key: question key holding the question text.
        :param labels_key: question key holding the label.
        :return: four parallel lists: contexts, question texts, label ids and
            question depths. The context is repeated for each of its questions.
        """
        label_field = str(labels_key)

        # One (context, question, label id, depth) row per question.
        rows = [
            (
                record[contexts_key],
                question[questions_text_key],
                RULETAKER_LABEL_TO_ID[str(question[label_field])],
                question["meta"]["QDep"],
            )
            for record in read_jsonl(self.dataset_path)
            for question in record[questions_key]
        ]

        contexts = [row[0] for row in rows]
        questions = [row[1] for row in rows]
        labels = [row[2] for row in rows]
        depths = [row[3] for row in rows]
        return contexts, questions, labels, depths

    def __getitem__(self, index: int) -> Tuple[str, str, int, int]:
        """Return the ``(context, question, label, depth)`` entry at ``index``."""
        context = self.contexts[index]
        question = self.questions[index]
        label = self.labels[index]
        depth = self.depths[index]
        return context, question, label, depth

    def __str__(self) -> str:
        """Return a one-line summary naming the split, sub-dataset and size."""
        return f"The {self.split_set} set of {self.dataset_name}'s RuleTaker has {len(self)} instances"

    def __len__(self) -> int:
        """Return the number of loaded (context, question) instances."""
        return len(self.contexts)