Source code for logitorch.models.prover

import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.init import xavier_normal_
from transformers import RobertaModel
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead

from logitorch.data_collators.prover_collator import PRoverProofWriterCollator


class _NodeClassificationHead(nn.Module):
    def __init__(self, config):
        """
        Initializes the node classification head module.

        Args:
            config: The configuration object.
        """
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

        xavier_normal_(self.dense.weight)
        xavier_normal_(self.out_proj.weight)

    def forward(self, features, **kwargs):
        """
        Forward pass of the node classification head module.

        Args:
            features: The input features.

        Returns:
            The output tensor.
        """
        x = self.dropout(features)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class _EdgeClassificationHead(nn.Module):
    def __init__(self, config):
        """
        Initializes the edge classification head module.

        Args:
            config: The configuration object.
        """
        super().__init__()
        self.dense = nn.Linear(3 * config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

        xavier_normal_(self.dense.weight)
        xavier_normal_(self.out_proj.weight)

    def forward(self, features, **kwargs):
        """
        Forward pass of the edge classification head module.

        Args:
            features: The input features.

        Returns:
            The output tensor.
        """
        x = self.dropout(features)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


[docs] class PRover(nn.Module): def __init__(self, pretrained_roberta_model: str, num_labels: int = 2) -> None: """ Initializes the PRover model. Args: pretrained_roberta_model: The path or name of the pretrained RoBERTa model. num_labels: The number of labels for classification. """ super().__init__()
[docs] self.num_labels = num_labels
[docs] self.num_labels_edge = num_labels
[docs] self.proofwriter_collator = PRoverProofWriterCollator(pretrained_roberta_model)
[docs] self.encoder = RobertaModel.from_pretrained(pretrained_roberta_model)
[docs] self.config = self.encoder.config
[docs] self.naf_layer = nn.Linear(self.config.hidden_size, self.config.hidden_size)
[docs] self.classifier = RobertaClassificationHead(self.config)
[docs] self.classifier_node = _NodeClassificationHead(self.config)
[docs] self.classifier_edge = _EdgeClassificationHead(self.config)
[docs] def forward( self, x, proof_offsets=None, node_labels=None, edge_labels=None, qa_labels=None, max_node_length=None, max_edge_length=None, device: str = "cpu", ): """ Forward pass of the PRover model. Args: x: The input tensor. proof_offsets: The proof offsets. node_labels: The node labels. edge_labels: The edge labels. qa_labels: The QA labels. max_node_length: The maximum node length. max_edge_length: The maximum edge length. device: The device to run the model on. Returns: The model outputs. """ outputs = self.encoder(**x) sequence_outputs = outputs[0] cls_outputs = sequence_outputs[:, 0, :] naf_outputs = self.naf_layer(cls_outputs) logits = self.classifier(sequence_outputs) if max_node_length is None: max_node_length = node_labels.shape[1] if max_edge_length is None: max_edge_length = edge_labels.shape[1] if node_labels is None: batch_size = 1 else: batch_size = node_labels.shape[0] embedding_dim = sequence_outputs.shape[2] batch_node_embedding = torch.zeros( (batch_size, max_node_length, embedding_dim) ).to(device) batch_edge_embedding = torch.zeros( (batch_size, max_edge_length, 3 * embedding_dim) ).to(device) for batch_index in range(batch_size): prev_index = 1 sample_node_embedding = None count = 0 for offset in proof_offsets[batch_index][1:]: if offset == 0: break else: rf_embedding = torch.mean( sequence_outputs[batch_index, prev_index : (offset + 1), :], dim=0, ).unsqueeze(0) prev_index = offset + 1 count += 1 if sample_node_embedding is None: sample_node_embedding = rf_embedding else: sample_node_embedding = torch.cat( (sample_node_embedding, rf_embedding), dim=0 ) sample_node_embedding = torch.cat( (sample_node_embedding, naf_outputs[batch_index].unsqueeze(0)), dim=0 ) repeat1 = sample_node_embedding.unsqueeze(0).repeat( len(sample_node_embedding), 1, 1 ) repeat2 = sample_node_embedding.unsqueeze(1).repeat( 1, len(sample_node_embedding), 1 ) sample_edge_embedding = torch.cat( (repeat1, repeat2, (repeat1 - repeat2)), dim=2 ) sample_edge_embedding = sample_edge_embedding.view( -1, sample_edge_embedding.shape[-1] ) if sample_node_embedding.shape[0] < max_node_length: sample_node_embedding = torch.cat( ( sample_node_embedding, torch.zeros((max_node_length - count - 1, embedding_dim)).to( device ), ), dim=0, ) sample_edge_embedding = torch.cat( ( sample_edge_embedding, torch.zeros( ( max_edge_length - len(sample_edge_embedding), 3 * embedding_dim, ) ).to(device), ), dim=0, ) batch_node_embedding[batch_index, :, :] = sample_node_embedding batch_edge_embedding[batch_index, :, :] = sample_edge_embedding node_logits = self.classifier_node(batch_node_embedding) edge_logits = self.classifier_edge(batch_edge_embedding) outputs = (logits, node_logits, edge_logits) + outputs[2:] if qa_labels is not None: loss_fct = CrossEntropyLoss() qa_loss = loss_fct(logits.view(-1, self.num_labels), qa_labels.view(-1)) node_loss = loss_fct( node_logits.view(-1, self.num_labels), node_labels.view(-1) ) edge_loss = loss_fct( edge_logits.view(-1, self.num_labels_edge), edge_labels.view(-1) ) total_loss = qa_loss + node_loss + edge_loss outputs = (total_loss, qa_loss, node_loss, edge_loss) + outputs return outputs
[docs] def predict(self, triples, rules, question, device: str = "cpu"): """ Predicts the label for a given question. Args: triples: The triples. rules: The rules. question: The question. device: The device to run the model on. Returns: The predicted label. """ with torch.no_grad(): context_tokens = [] proof_offset = [] sentences = ["<s>"] nfact = len(triples) nrule = len(rules) node_length = nfact + nrule + 1 edge_length = node_length**2 for s in triples.values(): sentences.append(s) for s in rules.values(): sentences.append(s) for s in sentences: sentence_tokens = self.proofwriter_collator.tokenizer.tokenize(s) context_tokens.extend(sentence_tokens) proof_offset.append(len(context_tokens)) sentences.append("</s>") sentences.append("</s>") sentences.append(question) sentences.append("</s>") context = "".join(sentences) proofs_offsets = torch.tensor([proof_offset]) tokenized_context = self.proofwriter_collator.tokenizer( [context], add_special_tokens=False, padding=True, return_tensors="pt" ) logits = self( tokenized_context.to(device), proofs_offsets.to(device), max_node_length=node_length, max_edge_length=edge_length, device=device, ) pred_qa_label = logits[0].argmax() return pred_qa_label.item()