Source code for logitorch.datasets.proof_qa.proofwriter_dataset

import os
import re
from typing import Dict, List, Optional, Tuple, Union

from logitorch.datasets.base import AbstractProofQADataset
from logitorch.datasets.exceptions import (
    AbductionClosedWorldAssumptionError,
    DatasetNameError,
    SplitSetError,
    TaskError,
)
from logitorch.datasets.utils import (
    DATASETS_FOLDER,
    SPLIT_SETS,
    download_dataset,
    read_jsonl,
)

# Download location of the zipped ProofWriter corpus.
PROOFWRITER_DATASET_ZIP_URL = (
    "https://www.dropbox.com/s/5dpm2yjcm5h9hsd/proofwriter_dataset.zip?dl=1"
)

# Name handed to ``download_dataset`` and used as the local folder name.
PROOFWRITER_DATASET = "proofwriter_dataset"

# Sub-dataset names accepted by the dataset classes below.
PROOFWRITER_SUB_DATASETS = [
    "birds-electricity",
    "depth-0",
    "depth-1",
    "depth-2",
    "depth-3",
    "depth-3ext",
    "depth-3ext-NatLang",
    "depth-5",
    "NatLang",
]

# Task names accepted by ``ProofWriterDataset``.
PROOFWRITER_TASKS = [
    "proof_generation_all",
    "proof_generation_iter",
    "implication_enumeration",
    "abduction",
]

# Local directory that is checked for existence before downloading.
PROOFWRITER_DATASET_FOLDER = f"{DATASETS_FOLDER}/{PROOFWRITER_DATASET}"

# Label string <-> integer id lookup tables (the second mirrors the first).
PROOFWRITER_LABEL_TO_ID = {"False": 0, "True": 1, "Unknown": 2}
PROOFWRITER_ID_TO_LABEL = {
    label_id: label for label, label_id in PROOFWRITER_LABEL_TO_ID.items()
}
class ProofWriterDataset(AbstractProofQADataset):
    """Index-able view over one ProofWriter ``.jsonl`` file.

    An instance is bound to one sub-dataset, one split, one task and one world
    assumption ("CWA" by default, "OWA" when ``open_world_assumption`` is
    true). The shape of the tuple returned by ``__getitem__`` depends on the
    task:

    * ``proof_generation_all``: ``(triples, rules, question, label, proofs,
      proofs_with_intermediates, depth)``
    * ``abduction``: ``(triples, rules, question, labels, proofs)``
    * ``proof_generation_iter``: ``(triples, rules, labels, proofs)``
    * ``implication_enumeration``: ``(triples, rules, labels)``

    ``triples`` and ``rules`` are dicts mapping an identifier to its text.

    Invalid arguments are reported with ``print`` instead of being raised to
    the caller; in that case the data attributes are never assigned, so any
    later use of the instance fails with ``AttributeError``.
    """

    def __init__(
        self,
        dataset_name: str,
        split_set: str,
        task: str,
        open_world_assumption: bool = False,
    ) -> None:
        """Validate the arguments, fetch the corpus if needed and load the file.

        :param dataset_name: one of ``PROOFWRITER_SUB_DATASETS``.
        :param split_set: a member of ``SPLIT_SETS``; ``"val"`` is mapped to
            ``"dev"``. ``"birds-electricity"`` only accepts ``"test"``.
        :param task: one of ``PROOFWRITER_TASKS``.
        :param open_world_assumption: read from the "OWA" folder instead of
            "CWA". The ``abduction`` task requires this to be true.
        """
        try:
            if dataset_name not in PROOFWRITER_SUB_DATASETS:
                raise DatasetNameError()

            if split_set != "test" and dataset_name == "birds-electricity":
                raise SplitSetError(["test"])

            if split_set == "val":
                split_set = "dev"
            elif split_set not in SPLIT_SETS:
                raise SplitSetError(SPLIT_SETS)

            if task not in PROOFWRITER_TASKS:
                raise TaskError()

            if not os.path.exists(PROOFWRITER_DATASET_FOLDER):
                download_dataset(PROOFWRITER_DATASET_ZIP_URL, PROOFWRITER_DATASET)

            self.dataset_name = dataset_name
            self.split_set = split_set
            self.task = task
            self.world_assumption = "OWA" if open_world_assumption else "CWA"

            if self.world_assumption == "CWA" and self.task == "abduction":
                raise AbductionClosedWorldAssumptionError()

            # Every task reads from the same folder; only the file name differs.
            folder = (
                f"{PROOFWRITER_DATASET_FOLDER}/{self.world_assumption}"
                f"/{self.dataset_name}"
            )

            if self.task == "proof_generation_all":
                self.dataset_path = f"{folder}/meta-{self.split_set}.jsonl"
                (
                    self.triples,
                    self.rules,
                    self.questions,
                    self.labels,
                    self.proofs,
                    # (sic) misspelled attribute name kept so existing callers
                    # that read it keep working.
                    self.proofs_intermerdiates,
                    self.depths,
                ) = self.__read_dataset_proof_generation_all(
                    "triples", "rules", "questions"
                )
            elif self.task == "proof_generation_iter":
                self.dataset_path = f"{folder}/meta-stage-{self.split_set}.jsonl"
                (
                    self.triples,
                    self.rules,
                    self.labels,
                    self.proofs,
                ) = self.__read_dataset_proof_generation_iter(
                    "triples", "rules", "allInferences"
                )
            elif self.task == "implication_enumeration":
                self.dataset_path = f"{folder}/meta-{self.split_set}.jsonl"
                (
                    self.triples,
                    self.rules,
                    self.labels,
                ) = self.__read_dataset_implication_enumeration(
                    "triples", "rules", "proofDetails"
                )
            elif self.task == "abduction":
                self.dataset_path = f"{folder}/meta-abduct-{self.split_set}.jsonl"
                (
                    self.triples,
                    self.rules,
                    self.questions,
                    self.labels,
                    self.proofs,
                ) = self.__read_dataset_abduction("triples", "rules", "abductions")

        except DatasetNameError as err:
            print(err.message)
            print(f"The ProofWriter datasets are: {PROOFWRITER_SUB_DATASETS}")
        except SplitSetError as err:
            print(err.message)
        except TaskError as err:
            print(err.message)
            print(f"The ProofWriter tasks are: {PROOFWRITER_TASKS}")
        except AbductionClosedWorldAssumptionError as err:
            print(err.message)

    @staticmethod
    def _texts_by_id(entries: Dict[str, dict]) -> Dict[str, str]:
        """Map every key of ``entries`` to the ``"text"`` field of its value."""
        return {key: val["text"] for key, val in entries.items()}

    def __read_dataset_proof_generation_all(
        self,
        triples_key: str,
        rules_key: str,
        questions_key: str,
    ) -> Tuple[
        List[Dict[str, str]],
        List[Dict[str, str]],
        List[List[str]],
        List[List[str]],
        List[List[str]],
        List[List[str]],
        List[List[int]],
    ]:
        """Load ``self.dataset_path`` and emit one example per question.

        All questions of a record share that record's triples and rules dicts.

        :return: seven parallel lists: triples, rules, questions, answers,
            proofs, proofs-with-intermediates strings and question depths.
        """
        data = read_jsonl(self.dataset_path)
        proofs_intermediates_key = "proofsWithIntermediates"

        triples_list = []
        rules_list = []
        questions_list = []
        labels_list = []
        proofs_list = []
        proofs_intermediates_list = []
        depths_list = []

        for record in data:
            triples = self._texts_by_id(record[triples_key])
            rules = self._texts_by_id(record[rules_key])

            for q in record[questions_key].values():
                question = q["question"]

                # Reset for every question: without this, a question lacking
                # "proofsWithIntermediates" would inherit the previous
                # question's proofs (or hit an unbound name on the first one).
                intermediate_proofs = []
                if proofs_intermediates_key in q:
                    for p in q[proofs_intermediates_key]:
                        str_proof = f"{p['representation']}"
                        if p["intermediates"]:
                            str_proof += " ; with "
                            # NOTE(review): several intermediates are
                            # concatenated with no separator between them;
                            # kept as-is so the produced strings do not change.
                            for intr, val in p["intermediates"].items():
                                str_proof += f"{intr} = {val['text']}"
                        intermediate_proofs.append(str_proof)

                answer = q["answer"]
                proof = q["proofs"]
                depth = q["QDep"]

                triples_list.append(triples)
                rules_list.append(rules)
                questions_list.append(question)
                labels_list.append(answer)
                proofs_list.append(proof)
                proofs_intermediates_list.append(intermediate_proofs)
                depths_list.append(depth)

        return (
            triples_list,
            rules_list,
            questions_list,
            labels_list,
            proofs_list,
            proofs_intermediates_list,
            depths_list,
        )

    def __read_dataset_proof_generation_iter(
        self, triples_key: str, rules_key: str, proofs_key: str
    ) -> Tuple[
        List[Dict[str, str]],
        List[Dict[str, str]],
        List[List[str]],
        List[List[str]],
    ]:
        """Load ``self.dataset_path`` and emit one example per record.

        :return: four parallel lists: triples, rules, inference texts and
            their proofs. A record without inferences contributes ``[None]``
            for both its labels and its proofs.
        """
        data = read_jsonl(self.dataset_path)

        triples_list = []
        rules_list = []
        labels_list = []
        proofs_list = []

        for record in data:
            triples = self._texts_by_id(record[triples_key])
            rules = self._texts_by_id(record[rules_key])

            inferences = []
            proofs = []
            for inference in record[proofs_key]:
                inferences.append(inference["text"])
                proofs.append(inference["proofs"])

            triples_list.append(triples)
            rules_list.append(rules)
            if inferences:
                labels_list.append(inferences)
                proofs_list.append(proofs)
            else:
                labels_list.append([None])
                proofs_list.append([None])

        return triples_list, rules_list, labels_list, proofs_list

    def __read_dataset_implication_enumeration(
        self, triples_key: str, rules_key: str, labels_key: str
    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], List[List[str]]]:
        """Load ``self.dataset_path`` and emit one example per record.

        :return: three parallel lists: triples, rules and the texts found under
            ``labels_key``. A record with no entries there contributes
            ``[None]``.
        """
        data = read_jsonl(self.dataset_path)

        triples_list = []
        rules_list = []
        labels_list = []

        for record in data:
            triples = self._texts_by_id(record[triples_key])
            rules = self._texts_by_id(record[rules_key])
            labels = [detail["text"] for detail in record[labels_key]]

            triples_list.append(triples)
            rules_list.append(rules)
            labels_list.append(labels if labels else [None])

        # A leftover debug print of element 30 used to sit here; it raised
        # IndexError for files with fewer than 31 records and has been removed.
        return triples_list, rules_list, labels_list

    def __read_dataset_abduction(
        self, triples_key: str, rules_key: str, abductions_key: str
    ) -> Tuple[
        List[Dict[str, str]],
        List[Dict[str, str]],
        List[List[str]],
        List[List[str]],
        List[List[str]],
    ]:
        """Load ``self.dataset_path`` and emit one example per abduction entry.

        All entries of a record share that record's triples and rules dicts.

        :return: five parallel lists: triples, rules, questions, answer texts
            and answer proofs. An entry without answers contributes ``[None]``
            for both its labels and its proofs.
        """
        data = read_jsonl(self.dataset_path)

        triples_list = []
        rules_list = []
        questions_list = []
        labels_list = []
        proofs_list = []

        for record in data:
            triples = self._texts_by_id(record[triples_key])
            rules = self._texts_by_id(record[rules_key])

            for abduc in record[abductions_key].values():
                question = abduc["question"]

                answer_texts = []
                answer_proofs = []
                if abduc["answers"]:
                    for answer in abduc["answers"]:
                        answer_texts.append(answer["text"])
                        answer_proofs.append(answer["proof"])
                else:
                    answer_texts.append(None)
                    answer_proofs.append(None)

                triples_list.append(triples)
                rules_list.append(rules)
                questions_list.append(question)
                labels_list.append(answer_texts)
                proofs_list.append(answer_proofs)

        return triples_list, rules_list, questions_list, labels_list, proofs_list

    def __getitem__(
        self, index: int
    ) -> Union[
        Tuple[
            Dict[str, str],
            Dict[str, str],
            List[str],
            List[str],
            List[str],
            List[str],
            List[int],
        ],
        Tuple[Dict[str, str], Dict[str, str], List[str], List[str], List[str]],
        Tuple[Dict[str, str], Dict[str, str], List[str], List[str]],
        Tuple[Dict[str, str], Dict[str, str], List[str]],
    ]:
        """Return the example at ``index``; the tuple layout depends on the task."""
        if self.task == "proof_generation_all":
            return (
                self.triples[index],
                self.rules[index],
                self.questions[index],
                self.labels[index],
                self.proofs[index],
                self.proofs_intermerdiates[index],
                self.depths[index],
            )
        elif self.task == "abduction":
            return (
                self.triples[index],
                self.rules[index],
                self.questions[index],
                self.labels[index],
                self.proofs[index],
            )
        elif self.task == "proof_generation_iter":
            return (
                self.triples[index],
                self.rules[index],
                self.labels[index],
                self.proofs[index],
            )
        else:
            # Remaining task: "implication_enumeration".
            return self.triples[index], self.rules[index], self.labels[index]

    def __str__(self) -> str:
        """Return a one-line summary naming split, sub-dataset, task and size."""
        return f'The {self.split_set} set of {self.dataset_name}\'s ProofWriter for the task of "{self.task}" has {self.__len__()} instances'

    def __len__(self) -> int:
        """Return the number of loaded examples."""
        return len(self.triples)
class FaiRRProofWriterDataset:
    """Staged ProofWriter records paired with their source record's questions.

    Two files are read for the chosen sub-dataset and split: the staged file
    (``meta-stage-<split>.jsonl``) and the plain file (``meta-<split>.jsonl``).
    Each staged record is matched to a plain record by stripping the
    ``-add<N>`` part from its id, and one example is produced for every
    question of that plain record.

    ``__getitem__`` returns ``(triples, rules, question, labels, proofs)``.

    Invalid arguments are reported with ``print`` instead of being raised to
    the caller; in that case the data attributes are never assigned.
    """

    def __init__(
        self,
        dataset_name: str,
        split_set: str,
        open_world_assumption: bool = False,
    ) -> None:
        """Validate the arguments, fetch the corpus if needed and load the data.

        :param dataset_name: one of ``PROOFWRITER_SUB_DATASETS``.
        :param split_set: a member of ``SPLIT_SETS``; ``"val"`` is mapped to
            ``"dev"``. ``"birds-electricity"`` only accepts ``"test"``.
        :param open_world_assumption: read from the "OWA" folder instead of
            "CWA".
        """
        try:
            if dataset_name not in PROOFWRITER_SUB_DATASETS:
                raise DatasetNameError()

            if split_set != "test" and dataset_name == "birds-electricity":
                raise SplitSetError(["test"])

            if split_set == "val":
                split_set = "dev"
            elif split_set not in SPLIT_SETS:
                raise SplitSetError(SPLIT_SETS)

            if not os.path.exists(PROOFWRITER_DATASET_FOLDER):
                download_dataset(PROOFWRITER_DATASET_ZIP_URL, PROOFWRITER_DATASET)

            self.dataset_name = dataset_name
            self.split_set = split_set
            self.world_assumption = "OWA" if open_world_assumption else "CWA"

            folder = (
                f"{PROOFWRITER_DATASET_FOLDER}/{self.world_assumption}"
                f"/{self.dataset_name}"
            )
            self.dataset_path = f"{folder}/meta-stage-{self.split_set}.jsonl"
            self.no_staged_dataset_path = f"{folder}/meta-{self.split_set}.jsonl"

            (
                self.triples,
                self.rules,
                self.questions,
                self.labels,
                self.proofs,
            ) = self.__read_dataset_proof_generation_iter(
                "triples", "rules", "allInferences"
            )
        except DatasetNameError as err:
            print(err.message)
            print(f"The ProofWriter datasets are: {PROOFWRITER_SUB_DATASETS}")
        except SplitSetError as err:
            print(err.message)

    @staticmethod
    def _expand_staged_records(
        staged_records,
        source_records,
        triples_key: str,
        rules_key: str,
        proofs_key: str,
    ) -> Tuple[
        List[Dict[str, str]],
        List[Dict[str, str]],
        List[str],
        List[List[Optional[str]]],
        List[List[Optional[str]]],
    ]:
        """Cross every staged record with the questions of its source record.

        :param staged_records: iterable of staged records; each needs an
            ``"id"`` plus the three keyed sections.
        :param source_records: iterable of records providing ``"id"`` and
            ``"questions"``.
        :return: five parallel lists: triples, rules, questions, inference
            texts and inference proofs. A staged record without inferences
            contributes ``[None]`` for both labels and proofs.
        :raises KeyError: if a staged record's stripped id has no source record.
        """
        questions_by_id = {
            record["id"]: [q["question"] for q in record["questions"].values()]
            for record in source_records
        }
        # Staged ids carry an "-add<N>" part that the source ids do not have.
        stage_suffix = re.compile(r"-add[0-9]+")

        triples_list = []
        rules_list = []
        questions_list = []
        labels_list = []
        proofs_list = []

        for record in staged_records:
            source_id = stage_suffix.sub("", record["id"])
            record_questions = questions_by_id[source_id]
            if not record_questions:
                continue

            # Parse the record once; it does not depend on the question.
            triples = {key: val["text"] for key, val in record[triples_key].items()}
            rules = {key: val["text"] for key, val in record[rules_key].items()}
            inferences = []
            proofs = []
            for inference in record[proofs_key]:
                inferences.append(inference["text"])
                proofs.append(inference["proofs"])

            for question in record_questions:
                # Copies keep every example independent of its siblings, so a
                # caller mutating one item does not affect the others.
                triples_list.append(dict(triples))
                rules_list.append(dict(rules))
                questions_list.append(question)
                if inferences:
                    labels_list.append(list(inferences))
                    proofs_list.append(list(proofs))
                else:
                    labels_list.append([None])
                    proofs_list.append([None])

        return triples_list, rules_list, questions_list, labels_list, proofs_list

    def __read_dataset_proof_generation_iter(
        self, triples_key: str, rules_key: str, proofs_key: str
    ) -> Tuple[
        List[Dict[str, str]],
        List[Dict[str, str]],
        List[str],
        List[List[Optional[str]]],
        List[List[Optional[str]]],
    ]:
        """Read both ``.jsonl`` files and build the five parallel example lists."""
        data = read_jsonl(self.dataset_path)
        no_staged_data = read_jsonl(self.no_staged_dataset_path)
        return self._expand_staged_records(
            data, no_staged_data, triples_key, rules_key, proofs_key
        )

    def __getitem__(self, index: int):
        """Return ``(triples, rules, question, labels, proofs)`` at ``index``."""
        return (
            self.triples[index],
            self.rules[index],
            self.questions[index],
            self.labels[index],
            self.proofs[index],
        )

    def __str__(self) -> str:
        """Return a one-line summary naming split, sub-dataset and size."""
        # This class has no ``task`` attribute (reading ``self.task`` here used
        # to raise AttributeError); the label names the reader it always uses.
        return f'The {self.split_set} set of {self.dataset_name}\'s ProofWriter for the task of "proof_generation_iter" has {self.__len__()} instances'

    def __len__(self) -> int:
        """Return the number of loaded examples."""
        return len(self.triples)