Source code for logitorch.data_collators.prover_collator

from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import RobertaTokenizer

from logitorch.datasets.proof_qa.proofwriter_dataset import PROOFWRITER_LABEL_TO_ID


[docs] class Node: def __init__(self, head: str) -> None:
[docs] self.head = head
def __str__(self) -> str: return str(self.head)
[docs] class PRoverProofWriterCollator: """ A collator class for processing data in the PRoverProofWriter format. Args: | pretrained_roberta_tokenizer (str): The name or path of the pretrained RoBERTa tokenizer. Attributes: | tokenizer (RobertaTokenizer): The pretrained RoBERTa tokenizer. Methods: | get_proof_graph_with_fail(proof_str: str) -> Tuple[List[str], List[str]]: Extracts the proof graph and edges from a proof string with a "FAIL" node. | get_proof_graph(proof_str: str) -> Tuple[List[str], List[Tuple[str, str]]]: Extracts the proof graph and edges from a proof string. | get_node_edge_label_constrained(x: str) -> Tuple[List[int], List[np.ndarray]]: Generates node and edge labels for a given input. | __call__(batch) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: Collates a batch of data into tokenized tensors. """ def __init__(self, pretrained_roberta_tokenizer: str) -> None:
[docs] self.tokenizer = RobertaTokenizer.from_pretrained(pretrained_roberta_tokenizer)
[docs] def get_proof_graph_with_fail(self, proof_str: str) -> Tuple[List[str], List[str]]: """ Extracts the proof graph and edges from a proof string with a "FAIL" node. Args: | proof_str (str): The proof string. Returns: | Tuple[List[str], List[str]]: A tuple containing the list of nodes and the list of edges. """ proof_str = proof_str[:-2].split("=")[1].strip()[1:-1] nodes = proof_str.split(" <- ") all_nodes = [] all_edges = [] for i in range(len(nodes) - 1): all_nodes.append(nodes[i]) if nodes[i + 1] != "FAIL": all_edges.append((nodes[i + 1], nodes[i])) return all_nodes, all_edges
[docs] def get_proof_graph( self, proof_str: str ) -> Tuple[List[str], List[Tuple[str, str]]]: """ Extracts the proof graph and edges from a proof string. Args: | proof_str (str): The proof string. Returns: | Tuple[List[str], List[Tuple[str, str]]]: A tuple containing the list of nodes and the list of edges. """ stack = [] last_open = 0 last_open_index = 0 pop_list = [] all_edges = [] all_nodes = [] proof_str = proof_str.replace("(", " ( ") proof_str = proof_str.replace(")", " ) ") proof_str = proof_str.split() should_join = False for i in range(len(proof_str)): _s = proof_str[i] x = _s.strip() if len(x) == 0: continue if x == "(": stack.append((x, i)) last_open = len(stack) - 1 last_open_index = i elif x == ")": for j in range(last_open + 1, len(stack)): if isinstance(stack[j][0], Node): pop_list.append((stack[j][1], stack[j][0])) stack = stack[:last_open] for j in range((len(stack))): if stack[j][0] == "(": last_open = j last_open_index = stack[j][1] elif x == "[" or x == "]": pass elif x == "->": should_join = True else: # terminal if x not in all_nodes: all_nodes.append(x) if should_join: new_pop_list = [] # Choose which ones to add the node to for (index, p) in pop_list: if index < last_open_index: new_pop_list.append((index, p)) else: all_edges.append((p.head, x)) pop_list = new_pop_list stack.append((Node(x), i)) should_join = False return all_nodes, all_edges
[docs] def get_node_edge_label_constrained( self, x: str ) -> Tuple[List[int], List[np.ndarray]]: """ Generates node and edge labels for a given input. Args: | x (str): The input. Returns: | Tuple[List[int], List[np.ndarray]]: A tuple containing the list of node labels and the list of edge labels. """ proofs = x[4] nrule = len(x[1]) sentence_scramble = [i[0] + 1 for i in enumerate(x[0])] nfact = len(sentence_scramble) sentence_scramble += [nfact + 1 + i[0] for i in enumerate(x[1])] proof = proofs.split("OR")[0] node_label = [0] * (nfact + nrule + 1) edge_label = np.zeros((nfact + nrule + 1, nfact + nrule + 1), dtype=int) if "FAIL" in proof: nodes, edges = self.get_proof_graph_with_fail(proof) else: nodes, edges = self.get_proof_graph(proof) component_index_map = {} for (i, index) in enumerate(sentence_scramble): if index <= nfact: component = "triple" + str(index) else: component = "rule" + str(index - nfact) component_index_map[component] = i component_index_map["NAF"] = nfact + nrule for node in nodes: index = component_index_map[node] node_label[index] = 1 edges = list(set(edges)) for edge in edges: start_index = component_index_map[edge[0]] end_index = component_index_map[edge[1]] edge_label[start_index][end_index] = 1 # Mask impossible edges for i in range(len(edge_label)): for j in range(len(edge_label)): # Ignore diagonal if i == j: edge_label[i][j] = -100 continue # Ignore edges between non-nodes if node_label[i] == 0 or node_label[j] == 0: edge_label[i][j] = -100 continue is_fact_start = False is_fact_end = False if i == len(edge_label) - 1 or sentence_scramble[i] <= nfact: is_fact_start = True if j == len(edge_label) - 1 or sentence_scramble[j] <= nfact: is_fact_end = True # No edge between fact/NAF -> fact/NAF if is_fact_start and is_fact_end: edge_label[i][j] = -100 continue # No edge between Rule -> fact/NAF if not is_fact_start and is_fact_end: edge_label[i][j] = -100 continue return node_label, list(edge_label.flatten())
def __call__( self, batch ) -> Tuple[ Dict[str, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor ]: """ Collates a batch of data into tokenized tensors. Args: | batch: The batch of data. Returns: | Tuple[Dict[str, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the tokenized batch, proof offsets, node labels, edge labels, and labels. """ contexts = [] proofs_offsets = [] node_labels = [] edge_labels = [] labels = [] for i in batch: context_tokens = [] proof_offset = [] sentences = ["<s>"] for s in i[0].values(): sentences.append(s) for s in i[1].values(): sentences.append(s) for s in sentences: sentence_tokens = self.tokenizer.tokenize(s) context_tokens.extend(sentence_tokens) proof_offset.append(len(context_tokens)) sentences.append("</s>") sentences.append("</s>") sentences.append(i[2]) sentences.append("</s>") contexts.append("".join(sentences)) proofs_offsets.append(torch.tensor(proof_offset)) node_label, edge_label = self.get_node_edge_label_constrained(i) node_labels.append(torch.tensor(node_label)) edge_labels.append(torch.LongTensor(edge_label)) labels.append(PROOFWRITER_LABEL_TO_ID[str(i[3])]) tokenized_batch = self.tokenizer( contexts, add_special_tokens=False, padding=True, return_tensors="pt" ) proofs_offsets = pad_sequence(proofs_offsets, batch_first=True) node_labels = pad_sequence(node_labels, batch_first=True, padding_value=-100) edge_labels = pad_sequence(edge_labels, batch_first=True, padding_value=-100) labels = torch.tensor(labels) return tokenized_batch, proofs_offsets, node_labels, edge_labels, labels