Source code for logitorch.datasets.mcqa.arlsat_dataset

import os
from typing import List, Tuple

from logitorch.datasets.base import AbstractMCQADataset
from logitorch.datasets.exceptions import SplitSetError
from logitorch.datasets.utils import (
    DATASETS_FOLDER,
    SPLIT_SETS,
    download_dataset,
    read_json,
)

# Remote archive holding every ARLSAT split as JSON files.
ARLSAT_DATASET_ZIP_URL = (
    "https://www.dropbox.com/s/yuaoz1kon66w2o6/arlsat_dataset.zip?dl=1"
)
# Name passed to download_dataset and used as the local folder name.
ARLSAT_DATASET = "arlsat_dataset"
ARLSAT_DATASET_FOLDER = f"{DATASETS_FOLDER}/{ARLSAT_DATASET}"

# Answer-option letters; a letter's position in this string is its integer id.
ARLSAT_LABEL_TO_ID = {letter: idx for idx, letter in enumerate("ABCDE")}
# Inverse mapping, derived so the two tables can never drift apart.
ARLSAT_ID_TO_LABEL = {idx: letter for letter, idx in ARLSAT_LABEL_TO_ID.items()}
class ARLSATDataset(AbstractMCQADataset):
    """
    ARLSAT dataset for multiple-choice question answering.

    Every question of every passage becomes one instance made of the passage
    text (context), the question text, its list of answer options and the
    integer id of the correct option (see ``ARLSAT_LABEL_TO_ID``).
    """

    def __init__(self, split_set: str) -> None:
        """
        Loads one split of the ARLSAT dataset, downloading the dataset first
        if its local folder does not exist yet.

        Args:
            split_set (str): The split to load; must be one of ``SPLIT_SETS``.
                An invalid split is reported with ``print`` (not raised) and
                leaves the instance empty.
        """
        super().__init__()
        # Start from a consistent empty state: an instance built with an
        # invalid split used to lack these attributes entirely, so any later
        # len()/str()/indexing call failed with AttributeError.
        self.split_set = split_set
        self.contexts: List[str] = []
        self.questions: List[str] = []
        self.answers: List[List[str]] = []
        self.labels: List[int] = []

        if split_set not in SPLIT_SETS:
            # Same reporting as before (message printed, nothing raised).
            print(SplitSetError(SPLIT_SETS).message)
            return

        if not os.path.exists(ARLSAT_DATASET_FOLDER):
            download_dataset(ARLSAT_DATASET_ZIP_URL, ARLSAT_DATASET)

        self.dataset_path = f"{ARLSAT_DATASET_FOLDER}/{self.split_set}.json"
        (
            self.contexts,
            self.questions,
            self.answers,
            self.labels,
        ) = self.__read_dataset(
            "passage", "questions", "question", "options", "answer"
        )

    def __read_dataset(
        self,
        contexts_key: str,
        questions_key: str,
        questions_text_key: str,
        answers_key: str,
        labels_key: str,
    ) -> Tuple[List[str], List[str], List[List[str]], List[int]]:
        """
        Reads the ARLSAT JSON file at ``self.dataset_path`` and flattens it to
        one entry per question.

        Args:
            contexts_key (str): The key for the passage text of an entry.
            questions_key (str): The key for the list of questions of an entry.
            questions_text_key (str): The key for the text of a question.
            answers_key (str): The key for the answer options of a question.
            labels_key (str): The key for the correct option letter of a question.

        Returns:
            Tuple[List[str], List[str], List[List[str]], List[int]]: The
            contexts, questions, answers, and labels, aligned by index.

        Raises:
            KeyError: If an entry lacks one of the keys, or a label letter is
                not in ``ARLSAT_LABEL_TO_ID``.
        """
        data = read_json(self.dataset_path)

        contexts_list: List[str] = []
        questions_list: List[str] = []
        answers_list: List[List[str]] = []
        labels_list: List[int] = []

        for entry in data:
            # The passage is shared by all of its questions; look it up once.
            context = entry[contexts_key]
            for question in entry[questions_key]:
                contexts_list.append(context)
                questions_list.append(question[questions_text_key])
                # Copy the options so each instance owns its own list.
                answers_list.append(list(question[answers_key]))
                labels_list.append(ARLSAT_LABEL_TO_ID[question[labels_key]])

        return contexts_list, questions_list, answers_list, labels_list

    def __getitem__(self, index: int) -> Tuple[str, str, List[str], int]:
        """
        Returns the item at the given index.

        Args:
            index (int): The index of the item.

        Returns:
            Tuple[str, str, List[str], int]: The context, question, answer
            options, and label of the item at the given index.
        """
        return (
            self.contexts[index],
            self.questions[index],
            self.answers[index],
            self.labels[index],
        )

    def __str__(self) -> str:
        """
        Returns a string representation of the dataset.

        Returns:
            str: A string naming the split and its number of instances.
        """
        return f"The {self.split_set} set of ARLSAT has {len(self)} instances"

    def __len__(self) -> int:
        """
        Returns the number of instances in the dataset.

        Returns:
            int: The number of instances in the dataset.
        """
        return len(self.contexts)