rllm.nn.models.RDL

class rllm.nn.models.RDL(data: HeteroGraphData, col_stats_dict: Dict[str, Dict[ColType, List[Dict[StatType, Any]]]], hidden_dim: int, out_dim: int, tnn_hidden_dim: int = 128, tnn_num_layers: int = 4, hgnn_aggr: str = 'mean', hgnn_num_layers: int = 2, use_temporal_encoder: bool = False, reg_task: bool = False)

Bases: Module

Relational Deep Learning (RDL) model from the paper “RelBench: A Benchmark for Deep Learning on Relational Databases”. The RDL model combines Table Neural Networks (TNNs) and Heterogeneous Graph Neural Networks (HGNNs) to learn from multi-table relational data. Following the original paper, this implementation always uses TableResNet as the TNN component and HeteroSAGE as the HGNN component, with an optional temporal encoding module (see use_temporal_encoder).

Parameters:
  • data (HeteroGraphData) – The heterogeneous graph data.

  • col_stats_dict (Dict[str, Dict[ColType, List[Dict[StatType, Any]]]]) – The column statistics dictionary for each table.

  • hidden_dim (int) – The hidden dimension.

  • out_dim (int) – The output dimension.

  • tnn_hidden_dim (int) – The hidden dimension for TNN. (default: 128)

  • tnn_num_layers (int) – The number of layers for TNN. (default: 4)

  • hgnn_aggr (str) – The aggregation method for HGNN. (default: 'mean')

  • hgnn_num_layers (int) – The number of layers for HGNN. (default: 2)

  • use_temporal_encoder (bool) – Whether to use the temporal encoder. (default: False)

  • reg_task (bool) – If True, uses a regression output head with GELU activation instead of a classification head. (default: False)
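The nested type of col_stats_dict can be hard to read from the annotation alone. The sketch below shows one plausible reading of it: table name, then column type, then a list with one statistics dictionary per column of that type. The ColType and StatType enums here are local stand-ins with hypothetical members, not the library's own classes, and the table names and statistic values are made up for illustration.

```python
from enum import Enum
from typing import Any, Dict, List


class ColType(Enum):
    # Stand-in for the library's ColType; member names are assumptions.
    NUMERICAL = "numerical"
    CATEGORICAL = "categorical"


class StatType(Enum):
    # Stand-in for the library's StatType; member names are assumptions.
    MEAN = "mean"
    STD = "std"
    COUNT = "count"


# table name -> column type -> one stats dict per column of that type
col_stats_dict: Dict[str, Dict[ColType, List[Dict[StatType, Any]]]] = {
    "users": {
        ColType.NUMERICAL: [
            {StatType.MEAN: 31.5, StatType.STD: 8.2},  # e.g. an "age" column
        ],
        ColType.CATEGORICAL: [
            {StatType.COUNT: 3},  # e.g. a "country" column with 3 categories
        ],
    },
    "orders": {
        ColType.NUMERICAL: [
            {StatType.MEAN: 42.0, StatType.STD: 10.0},  # e.g. "price"
            {StatType.MEAN: 2.0, StatType.STD: 1.0},    # e.g. "quantity"
        ],
    },
}


def count_columns(
    stats: Dict[str, Dict[ColType, List[Dict[StatType, Any]]]]
) -> Dict[str, int]:
    """Count the columns described for each table, across all column types."""
    return {
        table: sum(len(columns) for columns in by_type.values())
        for table, by_type in stats.items()
    }


columns_per_table = count_columns(col_stats_dict)
```

In practice the dictionary would be collected from the tables that make up the relational database rather than written by hand; the point here is only the nesting order.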

Example

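>>> from rllm.nn.models import RDL
>>> # hdata (HeteroGraphData) and col_stats_dict are assumed to have been
>>> # built beforehand from the relational dataset.
>>> model = RDL(
...     data=hdata,
...     col_stats_dict=col_stats_dict,
...     hidden_dim=128,
...     out_dim=1,
... )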
forward(batch: HeteroGraphData, target_table: str) → Tensor

Run table encoding, optional temporal encoding, HGNN propagation, and the output head.

Parameters:
  • batch (HeteroGraphData) – Batched heterogeneous relational graph data.

  • target_table (str) – The node type to predict.

Returns:

Output predictions for seed nodes in the target table, of shape [batch_size, out_dim].

Return type:

Tensor
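The order of stages that forward runs (table encoding, optional temporal encoding, HGNN propagation, output head) can be illustrated with a toy sketch in plain Python. This is not the library's implementation: forward_sketch, make_stage, and the list-based feature format are hypothetical stand-ins, and unlike the real method the sketch does not restrict the output to seed nodes.

```python
from typing import Callable, Dict, List, Optional

# Toy stand-in for batched features: table name -> one feature vector per row.
Features = Dict[str, List[List[float]]]
Rows = List[List[float]]


def forward_sketch(
    batch: Features,
    target_table: str,
    table_encoder: Callable[[Features], Features],
    hgnn: Callable[[Features], Features],
    head: Callable[[Rows], Rows],
    temporal_encoder: Optional[Callable[[Features], Features]] = None,
) -> Rows:
    x = table_encoder(batch)          # 1. per-table encoding (TNN)
    if temporal_encoder is not None:  # 2. optional temporal encoding
        x = temporal_encoder(x)
    x = hgnn(x)                       # 3. propagation across tables (HGNN)
    return head(x[target_table])      # 4. output head on the target table


calls: List[str] = []


def make_stage(name: str) -> Callable[[Features], Features]:
    """Build an identity stage that records when it is called."""
    def stage(x: Features) -> Features:
        calls.append(name)
        return x
    return stage


toy_batch: Features = {
    "users": [[1.0, 2.0], [3.0, 4.0]],
    "orders": [[5.0, 6.0]],
}

out = forward_sketch(
    toy_batch,
    "users",
    table_encoder=make_stage("tnn"),
    hgnn=make_stage("hgnn"),
    head=lambda rows: [[sum(r)] for r in rows],  # toy head with out_dim = 1
)
```

With two rows in the target table and a one-dimensional head, out has two rows of one value each, mirroring the documented [batch_size, out_dim] shape. No temporal encoder is passed, so that stage is skipped, as with use_temporal_encoder=False.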

reset_parameters()

Reset all learnable parameters of the module.