# Source code for rllm.nn.models.bridge

from typing import Any, Dict, List, Type, Union

import torch
from torch import Tensor
import torch.nn.functional as F

from rllm.types import ColType
from rllm.data import TableData
from rllm.nn.encoder import TabTransformerPreEncoder
from rllm.nn.conv.table_conv import TabTransformerConv
from rllm.nn.conv.graph_conv import GCNConv


class TableEncoder(torch.nn.Module):
    r"""Table-side encoder of the BRIDGE method.

    The feature dictionary of a :class:`rllm.data.TableData` is first passed
    through a :class:`TabTransformerPreEncoder` and then through
    ``num_layers`` table convolutions. The tensors of the resulting feature
    dictionary are concatenated along dimension 1 and averaged over that
    same dimension, so a single tensor is returned.

    Args:
        in_dim (int): Input dimensionality of the table data. It is part of
            the constructor interface but is not read by this constructor.
        out_dim (int): Output dimensionality for the encoded table data; it
            is handed to the pre-encoder as ``out_dim`` and to every
            convolution as ``conv_dim``.
        num_layers (int, optional):
            Number of convolution layers (default: :obj:`1`).
        metadata (Dict[ColType, List[Dict[str, Any]]], optional):
            Metadata for each column type, forwarded unchanged to the
            pre-encoder (default: :obj:`None`).
        table_conv (Type[torch.nn.Module], optional):
            The convolution module class instantiated for each layer
            (default: :obj:`rllm.nn.conv.table_conv.TabTransformerConv`).
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_layers: int = 1,
        metadata: Dict[ColType, List[Dict[str, Any]]] = None,
        table_conv: Type[torch.nn.Module] = TabTransformerConv,
    ) -> None:

        super().__init__()

        # The (still empty) layer list is registered before the pre-encoder
        # and only filled afterwards, so the order in which sub-modules are
        # registered and constructed is: convs slot, pre-encoder, conv layers.
        self.convs = torch.nn.ModuleList()
        self.pre_encoder = TabTransformerPreEncoder(
            out_dim=out_dim,
            metadata=metadata,
        )
        self.convs.extend(
            [table_conv(conv_dim=out_dim) for _ in range(num_layers)]
        )

    def forward(self, table: TableData) -> Tensor:
        """Encode ``table`` into a single tensor.

        Args:
            table (TableData): Table whose ``feat_dict`` is encoded.

        Returns:
            Tensor: Mean over dimension 1 of the concatenated per-key
            features. Presumably one row embedding per table row -- the
            exact shape depends on the pre-encoder/convolution outputs.
        """
        feats = self.pre_encoder(table.feat_dict, return_dict=True)
        for layer in self.convs:
            feats = layer(feats)
        # Merge the values of the feature dictionary along dim 1, then
        # reduce that dimension by averaging.
        merged = torch.cat(list(feats.values()), dim=1)
        return merged.mean(dim=1)


class GraphEncoder(torch.nn.Module):
    r"""GraphEncoder is a submodule of the BRIDGE method,
    which performs multi-layer convolution of the incoming graph.
    It takes as input a node feature tensor and the graph structure, given
    either as a single adjacency tensor (full-batch) or as a list of
    adjacency tensors (mini-batch). Dropout is applied before every
    convolution and a ReLU follows every convolution except the last one.

    Args:
        in_dim (int): Input dimensionality of the data. All layers except
            the last map ``in_dim`` to ``in_dim``.
        out_dim (int): Output dimensionality for the encoded data, produced
            by the last layer.
        dropout (float): Dropout probability (default: :obj:`0.5`).
        num_layers (int): The number of layers of the convolution
            (default: :obj:`2`). One output layer is always created, so
            values below 1 still yield a single layer.
        graph_conv (Type[torch.nn.Module], optional):
            The convolution module to be used for encoding the graph data
            (default: :obj:`rllm.nn.conv.graph_conv.GCNConv`).
        norm (bool, optional): Forwarded to every convolution as its
            ``normalize`` keyword argument (default: :obj:`False`).
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.5,
        num_layers: int = 2,
        graph_conv: Type[torch.nn.Module] = GCNConv,
        norm: bool = False,
    ) -> None:
        super().__init__()
        self.dropout = dropout
        self.convs = torch.nn.ModuleList()

        # Hidden layers keep the feature width; only the last layer projects
        # to ``out_dim``.
        for _ in range(num_layers - 1):
            self.convs.append(graph_conv(in_dim=in_dim, out_dim=in_dim, normalize=norm))
        self.convs.append(graph_conv(in_dim=in_dim, out_dim=out_dim, normalize=norm))

    def forward(
        self,
        x: Tensor,
        adj: Union[Tensor, List[Tensor]],
    ) -> Tensor:
        """Run the stacked graph convolutions.

        Args:
            x (Tensor): Node features.
            adj (Union[Tensor, List[Tensor]]): A single adjacency tensor
                shared by all layers (full-batch training or full test), or
                a list of adjacency tensors (batch training). With a list,
                the ``i``-th hidden layer uses ``adj[-i - 1]`` and the
                output layer uses ``adj[0]``.

        Returns:
            Tensor: Output of the last convolution (no activation applied).

        Raises:
            TypeError: If ``adj`` is neither a ``Tensor`` nor a ``list``.
                Previously this case fell through and returned ``None``.
        """
        hidden_convs = self.convs[:-1]
        num_hidden = len(hidden_convs)

        if isinstance(adj, Tensor):
            # Full batch training or full test: same adjacency everywhere.
            hidden_adjs = [adj] * num_hidden
            last_adj = adj
        elif isinstance(adj, list):
            # Batch training: hidden layers walk the list from the back,
            # the output layer uses the first entry.
            hidden_adjs = [adj[-i - 1] for i in range(num_hidden)]
            last_adj = adj[0]
        else:
            raise TypeError(
                "adj must be a Tensor or a list of Tensors, "
                f"got {type(adj).__name__}"
            )

        for conv, layer_adj in zip(hidden_convs, hidden_adjs):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = F.relu(conv(x, layer_adj))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, last_adj)


class BRIDGE(torch.nn.Module):
    r"""The BRIDGE model introduced in the `"rLLM: Relational Table Learning
    with LLMs" <https://arxiv.org/abs/2407.20157>`__ paper.
    BRIDGE is a simple RTL method based on rLLM framework, which
    combines table neural networks (TNNs) and graph neural networks (GNNs) to
    deal with multi-table data and their interrelationships, and uses "foreign
    keys" to build relationships and analyze them to improve the performance of
    multi-table joint learning tasks.

    Args:
        table_encoder (TableEncoder): Encoder for tabular data.
        graph_encoder (GraphEncoder): Encoder for graph data.

    Example:
        >>> from rllm.nn.models.bridge import BRIDGE, TableEncoder, GraphEncoder
        >>> model = BRIDGE(TableEncoder(16, 32, metadata={}), GraphEncoder(32, 8))
    """

    def __init__(
        self,
        table_encoder: TableEncoder,
        graph_encoder: GraphEncoder,
    ) -> None:
        super().__init__()
        self.table_encoder = table_encoder
        self.graph_encoder = graph_encoder

    def forward(
        self,
        table: TableData,
        non_table: Union[Tensor, None],
        adj: Union[Tensor, List[Tensor]],
    ) -> Tensor:
        """
        First, the Table Neural Network (TNN) learns the tabular data.
        Second, the learned representations are concatenated with the
        non-tabular data (when provided).
        Third, the Graph Neural Network (GNN) processes the combined data
        along with the adjacency matrix to learn the overall representation.

        Args:
            table (TableData): Input tabular data.
            non_table (Tensor or None): Input non-tabular data. When not
                ``None`` it is appended below the table embeddings along
                dimension 0; when ``None`` only the table embeddings are
                used as node features.
            adj (Union[Tensor, List[Tensor]]): Adjacency matrix, or list of
                adjacency matrices, passed through to the graph encoder.

        Returns:
            Tensor: Output table embedding features, i.e. the first
            ``len(table)`` rows of the graph encoder output.
        """
        t_embedds = self.table_encoder(table)
        if non_table is not None:
            # Table rows come first so they can be sliced back out below.
            node_feats = torch.cat([t_embedds, non_table], dim=0)
        else:
            node_feats = t_embedds
        node_feats = self.graph_encoder(node_feats, adj)
        # Keep only the rows that correspond to the table entries.
        return node_feats[: len(table), :]