import os
from typing import List, Optional, Tuple

from logitorch.datasets.exceptions import DatasetNameError, SplitSetError
from logitorch.datasets.utils import DATASETS_FOLDER, download_dataset, read_jsonl
# Download location of the zipped Wiki20K dataset (passed to download_dataset).
WIKI20K_DATASET_ZIP_URL = (
    "https://www.dropbox.com/s/yeh70n6etbg0a95/wiki20k_dataset.zip?dl=1"
)
# Name of the dataset; also the name of its folder under DATASETS_FOLDER.
WIKI20K_DATASET = "wiki20k_dataset"
# Sub-dataset names accepted by Wiki20KDataset; each maps to "<name>.jsonl".
WIKI20K_SUB_DATASETS = ["lm_wiki20k", "positive_lm_wiki20k", "negated_lm_wiki20k"]
# Local folder holding the extracted sub-dataset files.
WIKI20K_DATASET_FOLDER = f"{DATASETS_FOLDER}/{WIKI20K_DATASET}"
class Wiki20KDataset:
    """
    A class representing the Wiki20K dataset.

    Each instance loads one sub-dataset (one of ``WIKI20K_SUB_DATASETS``) from
    a JSON Lines file, downloading the dataset archive first if its local
    folder does not exist yet.

    Attributes:
        dataset_name (str): The name of the sub-dataset.
        dataset_path (str): The path to the sub-dataset's ``.jsonl`` file
            (only set when the name is valid).
        sentences (List[str]): The sentences in the dataset.
        labels (List[str]): The labels in the dataset, aligned with ``sentences``.

    Methods:
        __init__(self, dataset_name: str, size: Optional[int] = None) -> None:
            Initializes a Wiki20KDataset object.
        __read_dataset(self, sentences_key: str, labels_key: str, size: Optional[int] = None) -> Tuple[List[str], List[str]]:
            Reads the dataset file and returns the sentences and labels.
        __getitem__(self, index: int) -> Tuple[str, str]:
            Returns the sentence and label at the given index.
        __str__(self) -> str:
            Returns a string representation of the dataset.
        __len__(self) -> int:
            Returns the number of instances in the dataset.
    """

    def __init__(self, dataset_name: str, size: Optional[int] = None) -> None:
        """
        Initializes a Wiki20KDataset object.

        Errors are reported by printing a message rather than raising; in that
        case the instance is left as an empty dataset.

        Args:
            dataset_name (str): The name of the sub-dataset; must be one of
                ``WIKI20K_SUB_DATASETS``.
            size (int, optional): The maximum number of instances to load from
                the dataset. Defaults to None (load every instance).
        """
        super().__init__()
        # Start from a consistent empty state so len()/str()/indexing keep
        # working if loading fails below (errors are printed, not raised).
        self.dataset_name = dataset_name
        self.sentences: List[str] = []
        self.labels: List[str] = []
        try:
            if dataset_name not in WIKI20K_SUB_DATASETS:
                raise DatasetNameError()
            # Download and extract the archive only on first use.
            if not os.path.exists(WIKI20K_DATASET_FOLDER):
                download_dataset(WIKI20K_DATASET_ZIP_URL, WIKI20K_DATASET)
            self.dataset_path = f"{WIKI20K_DATASET_FOLDER}/{self.dataset_name}.jsonl"
            self.sentences, self.labels = self.__read_dataset("sentence", "label", size)
        except DatasetNameError as err:
            print(err.message)
            print(f"The Wiki20K datasets are: {WIKI20K_SUB_DATASETS}")
        except SplitSetError as err:
            # NOTE(review): nothing in this class raises SplitSetError (there is
            # no split argument); kept in case a helper raises it — confirm.
            print(err.message)

    def __read_dataset(
        self, sentences_key: str, labels_key: str, size: Optional[int] = None
    ) -> Tuple[List[str], List[str]]:
        """
        Reads the Wiki20K sub-dataset file at ``self.dataset_path``.

        Args:
            sentences_key (str): The key holding the sentence in each record.
            labels_key (str): The key holding the label in each record.
            size (int, optional): The maximum number of records to read.
                Defaults to None (read every record).

        Returns:
            Tuple[List[str], List[str]]: The sentences and the labels, in file order.
        """
        data = read_jsonl(self.dataset_path)
        records = data if size is None else data[:size]
        sentences_list = [record[sentences_key] for record in records]
        labels_list = [record[labels_key] for record in records]
        return sentences_list, labels_list

    def __getitem__(self, index: int) -> Tuple[str, str]:
        """
        Returns the sentence and label at the given index.

        Args:
            index (int): The index of the instance to retrieve.

        Returns:
            Tuple[str, str]: The sentence and its label.
        """
        return self.sentences[index], self.labels[index]

    def __str__(self) -> str:
        """
        Returns a string representation of the dataset.

        Returns:
            str: A sentence naming the sub-dataset and its number of instances.
        """
        return f"The set of {self.dataset_name}'s Wiki20K has {len(self)} instances"

    def __len__(self) -> int:
        """
        Returns the number of instances in the dataset.

        Returns:
            int: The number of loaded sentences.
        """
        return len(self.sentences)