"""Relational Deep Learning (RDL) model (source module ``rllm.nn.models.rdl``)."""

from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import ModuleDict

from rllm.types import ColType, StatType
from rllm.data import HeteroGraphData
from rllm.nn.models import TableResNet, HeteroSAGE
from rllm.nn.encoder import HeteroTemporalEncoder


class RDL(torch.nn.Module):
    r"""Relational Deep Learning (RDL) model from paper
    `"RelBench: A Benchmark for Deep Learning on Relational Databases"
    <https://arxiv.org/abs/2407.20060>`_.

    The RDL model combines Table Neural Networks (TNNs) and Heterogeneous
    Graph Neural Networks (HGNNs) to effectively learn from multi-table
    relational data. Following the original paper, TableResNet is used as
    the TNN component and HeteroSAGE as the HGNN component, optionally
    combined with a temporal encoding module.

    Args:
        data (HeteroGraphData): The heterogeneous graph data. Every node type
            in ``data.node_types`` must have an entry in ``col_stats_dict``.
        col_stats_dict (Dict[str, Dict[ColType, List[Dict[StatType, Any]]]]):
            The column statistics dictionary for each table.
        hidden_dim (int): The hidden dimension.
        out_dim (int): The output dimension.
        tnn_hidden_dim (int): The hidden dimension for TNN.
            (default: :obj:`128`)
        tnn_num_layers (int): The number of layers for TNN.
            (default: :obj:`4`)
        hgnn_aggr (str): The aggregation method for HGNN.
            (default: :obj:`'mean'`)
        hgnn_num_layers (int): The number of layers for HGNN.
            (default: :obj:`2`)
        use_temporal_encoder (bool): Whether to use the temporal encoder.
            (default: :obj:`False`)
        reg_task (bool): If :obj:`True`, uses a regression output head
            (``Linear -> GELU -> Linear``) instead of the classification head
            (``Linear -> BatchNorm1d -> ReLU -> Linear``).
            (default: :obj:`False`)

    Raises:
        ValueError: If a node type of ``data`` has no entry in
            ``col_stats_dict``.

    Example:
        >>> from rllm.nn.models import RDL
        >>> model = RDL(
        ...     data=hdata,
        ...     col_stats_dict=col_stats_dict,
        ...     hidden_dim=128,
        ...     out_dim=1,
        ... )
    """

    def __init__(
        self,
        data: HeteroGraphData,
        col_stats_dict: Dict[str, Dict[ColType, List[Dict[StatType, Any]]]],
        hidden_dim: int,
        out_dim: int,
        # TNN args
        tnn_hidden_dim: int = 128,
        tnn_num_layers: int = 4,
        # HGNN args
        hgnn_aggr: str = "mean",
        hgnn_num_layers: int = 2,
        # Temporal Encoder args
        use_temporal_encoder: bool = False,
        # Output head args
        reg_task: bool = False,
    ):
        super().__init__()

        # Validate input. An explicit raise is used instead of ``assert``
        # so the check still runs under ``python -O``.
        for node_type in data.node_types:
            if node_type not in col_stats_dict:
                raise ValueError(
                    f"Node type {node_type} not found in col_stats_dict"
                )

        # One TableResNet per node type (table), each projecting its rows
        # into the shared ``hidden_dim`` space consumed by the HGNN.
        self.TNN_DICT = ModuleDict(
            {
                node_type: TableResNet(
                    hidden_dim=tnn_hidden_dim,
                    out_dim=hidden_dim,
                    num_layers=tnn_num_layers,
                    metadata=col_stats_dict[node_type],
                )
                for node_type in data.node_types
            }
        )

        self.use_temporal_encoder = use_temporal_encoder
        if use_temporal_encoder:
            # Only node types whose storage contains a "time" entry are
            # given a temporal encoding.
            self.TEMPORAL_ENCODER = HeteroTemporalEncoder(
                node_types=[
                    node_type
                    for node_type in data.node_types
                    if "time" in data[node_type]
                ],
                channels=hidden_dim,
            )
        else:
            # Register an empty slot so ``self.TEMPORAL_ENCODER`` resolves to
            # None. ``register_parameter`` itself returns None, so its result
            # is not assigned back.
            self.register_parameter("TEMPORAL_ENCODER", None)

        self.HGNN = HeteroSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            hidden_dim=hidden_dim,
            aggr=hgnn_aggr,
            num_layers=hgnn_num_layers,
        )

        if reg_task:
            self.OUTPUT_HEAD = torch.nn.Sequential(
                torch.nn.Linear(hidden_dim, out_dim),
                torch.nn.GELU(),
                torch.nn.Linear(out_dim, out_dim),
            )
        else:
            self.OUTPUT_HEAD = torch.nn.Sequential(
                torch.nn.Linear(hidden_dim, out_dim),
                torch.nn.BatchNorm1d(out_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(out_dim, out_dim),
            )

        self.reset_parameters()

    def reset_parameters(self):
        r"""Reset all learnable parameters of the module."""
        for tnn in self.TNN_DICT.values():
            tnn.reset_parameters()
        if self.use_temporal_encoder:
            self.TEMPORAL_ENCODER.reset_parameters()
        self.HGNN.reset_parameters()
        # Activation layers in the head have no parameters, so only children
        # that expose ``reset_parameters`` are reset.
        for module in self.OUTPUT_HEAD.children():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

    def forward(
        self,
        batch: HeteroGraphData,
        target_table: str,
    ) -> Tensor:
        r"""Run table encoding, optional temporal encoding, HGNN propagation,
        and the output head.

        Args:
            batch (HeteroGraphData): Batched heterogeneous relational graph
                data.
            target_table (str): The node type to predict.

        Returns:
            Tensor: Output predictions for seed nodes in the target table, of
            shape :obj:`[batch_size, out_dim]`.

        Raises:
            ValueError: If the temporal encoder is enabled and ``batch`` lacks
                ``time_dict`` or ``batch_dict``.
        """
        seed_time = batch[target_table].seed_time

        # 1. apply TNN to each node type (table)
        x_dict = {}
        for node_type, node_storage in batch.node_items():
            x_dict[node_type] = self.TNN_DICT[node_type](node_storage.table)

        # 2. (optional) apply TEMPORAL_ENCODER to each temporal node type
        if self.use_temporal_encoder:
            # Explicit checks (not ``assert``) so they survive ``python -O``.
            for attr in ("time_dict", "batch_dict"):
                if not hasattr(batch, attr):
                    raise ValueError(
                        f"use_temporal_encoder=True requires the batch to "
                        f"provide `{attr}`"
                    )
            rel_time_dict = self.TEMPORAL_ENCODER(
                seed_time, batch.time_dict, batch.batch_dict
            )
            for node_type, rel_time in rel_time_dict.items():
                # Out-of-place add keeps the TNN output intact for autograd.
                x_dict[node_type] = x_dict[node_type] + rel_time

        # 3. apply HGNN
        x_dict = self.HGNN(x_dict, batch.edge_index_dict)

        # 4. apply OUTPUT_HEAD to target table. Only the first
        # ``seed_time.size(0)`` rows are used; these are assumed to be the
        # seed nodes of the batch.
        return self.OUTPUT_HEAD(x_dict[target_table][: seed_time.size(0)])