rllm.nn.models.RDL
- class rllm.nn.models.RDL(data: HeteroGraphData, col_stats_dict: Dict[str, Dict[ColType, List[Dict[StatType, Any]]]], hidden_dim: int, out_dim: int, tnn_hidden_dim: int = 128, tnn_num_layers: int = 4, hgnn_aggr: str = 'mean', hgnn_num_layers: int = 2, use_temporal_encoder: bool = False, reg_task: bool = False)
Bases: Module
Relational Deep Learning (RDL) model from the paper “RelBench: A Benchmark for Deep Learning on Relational Databases”. The RDL model combines Table Neural Networks (TNNs) and Heterogeneous Graph Neural Networks (HGNNs) to learn from multi-table relational data: rows are encoded per table by a TNN, and the resulting row embeddings are propagated along primary-foreign key links by an HGNN. Following the original paper, we use TableResNet as the TNN component and HeteroSAGE as the HGNN component, with an optional temporal encoding module (see use_temporal_encoder).
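To make the TNN/HGNN split concrete, the plain-Python sketch below treats a toy two-table database as a heterogeneous graph: every row is a node, every foreign key becomes an edge, and each user node is updated with the mean of its neighboring order embeddings, the kind of 'mean' aggregation a SAGE-style layer performs. This is an illustration under stated assumptions, not rllm code: the tables, the single-float "embeddings" standing in for TNN outputs, and the identity-weight combine step (self plus neighbor mean, where HeteroSAGE would apply learned linear maps) are all hypothetical.

```python
from collections import defaultdict
from statistics import fmean

# Hypothetical toy database with two tables linked by a foreign key.
# Each row is a node; a single float stands in for the row embedding a TNN would produce.
users = {0: 1.0, 1: 3.0}                               # user_id -> row embedding
orders = {10: (0, 2.0), 11: (0, 4.0), 12: (1, 6.0)}    # order_id -> (user_id FK, embedding)

# Each foreign key becomes a typed edge ("orders" -> "users").
edges = [(order_id, user_id) for order_id, (user_id, _) in orders.items()]


def mean_aggregate_users(users, orders, edges):
    """One SAGE-style step for the 'users' node type (toy version).

    Each user's own embedding is combined with the mean of its neighboring
    order embeddings. Identity weights are assumed; a real layer learns them.
    """
    incoming = defaultdict(list)
    for order_id, user_id in edges:
        incoming[user_id].append(orders[order_id][1])
    updated = {}
    for user_id, own in users.items():
        neigh = fmean(incoming[user_id]) if incoming[user_id] else 0.0
        updated[user_id] = own + neigh  # toy combine: self + neighbor mean
    return updated


print(mean_aggregate_users(users, orders, edges))  # {0: 4.0, 1: 9.0}
```

Stacking hgnn_num_layers such steps lets information travel that many foreign-key hops away from each row.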
- Parameters:
data (HeteroGraphData) – The heterogeneous graph data.
col_stats_dict (Dict[str, Dict[ColType, List[Dict[StatType, Any]]]]) – The column statistics dictionary for each table.
hidden_dim (int) – The hidden dimension.
out_dim (int) – The output dimension.
tnn_hidden_dim (int) – The hidden dimension for TNN. (default: 128)
tnn_num_layers (int) – The number of layers for TNN. (default: 4)
hgnn_aggr (str) – The aggregation method for HGNN. (default: 'mean')
hgnn_num_layers (int) – The number of layers for HGNN. (default: 2)
use_temporal_encoder (bool) – Whether to use the temporal encoder. (default: False)
reg_task (bool) – If True, uses a regression output head with GELU activation instead of classification. (default: False)
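The nested type of col_stats_dict is easiest to read from an example. The sketch below assumes that each list under a ColType holds one statistics dict per column of that type, in column order. ColType and StatType here are local stand-ins for rllm's enums (the real members may differ), and the table names and numbers are made up for illustration.

```python
from enum import Enum
from typing import Any, Dict, List


# Stand-ins for rllm's ColType / StatType enums (hypothetical members).
class ColType(Enum):
    NUMERICAL = "numerical"
    CATEGORICAL = "categorical"


class StatType(Enum):
    MEAN = "mean"
    STD = "std"
    COUNT = "count"


# table name -> column type -> [one stats dict per column of that type] (assumed layout)
col_stats_dict: Dict[str, Dict[ColType, List[Dict[StatType, Any]]]] = {
    "users": {
        ColType.NUMERICAL: [{StatType.MEAN: 35.2, StatType.STD: 11.7}],  # e.g. age
        ColType.CATEGORICAL: [{StatType.COUNT: 3}],                      # e.g. country
    },
    "orders": {
        ColType.NUMERICAL: [{StatType.MEAN: 20.0, StatType.STD: 5.0}],   # e.g. amount
    },
}


def num_columns(stats: Dict[ColType, List[Dict[StatType, Any]]]) -> int:
    """Total number of columns described for one table under the assumed layout."""
    return sum(len(per_type) for per_type in stats.values())


print({table: num_columns(stats) for table, stats in col_stats_dict.items()})
# {'users': 2, 'orders': 1}
```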
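Example
Here hdata is a HeteroGraphData instance and col_stats_dict holds the column statistics for each of its tables; both are assumed to have been built beforehand.
>>> from rllm.nn.models import RDL
>>> model = RDL(
...     data=hdata,
...     col_stats_dict=col_stats_dict,
...     hidden_dim=128,
...     out_dim=1,
... )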
- forward(batch: HeteroGraphData, target_table: str) → Tensor
Run table encoding, optional temporal encoding, HGNN propagation, and the output head.
- Parameters:
batch (HeteroGraphData) – Batched heterogeneous relational graph data.
target_table (str) – The node type to predict.
- Returns:
Output predictions for seed nodes in the target table, of shape [batch_size, out_dim].
- Return type:
Tensor
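The order of operations in forward can be sketched in plain Python with hypothetical callables standing in for the real submodules: encode_table for the TableResNet encoders, temporal_encode for the temporal encoder, hgnn for the HeteroSAGE stack, and head for the output head. Two assumptions are labeled in the code: temporal encoding runs only when an encoder is supplied (mirroring use_temporal_encoder), and the seed nodes are taken to be the first batch_size rows of the target table, a common convention for neighbor-sampled mini-batches that rllm may or may not follow.

```python
from typing import Callable, Dict, List, Optional

Vec = List[float]


def rdl_forward_sketch(
    tables: Dict[str, List[Vec]],                      # node type -> raw row features
    encode_table: Callable[[str, List[Vec]], List[Vec]],
    temporal_encode: Optional[Callable[[str, List[Vec]], List[Vec]]],
    hgnn: Callable[[Dict[str, List[Vec]]], Dict[str, List[Vec]]],
    head: Callable[[Vec], Vec],
    target_table: str,
    batch_size: int,
) -> List[Vec]:
    """Hypothetical control flow of RDL.forward; not the rllm implementation."""
    # 1. Per-table encoding (the TNN role).
    x = {name: encode_table(name, rows) for name, rows in tables.items()}
    # 2. Optional temporal encoding (assumed: only if an encoder is configured).
    if temporal_encode is not None:
        x = {name: temporal_encode(name, h) for name, h in x.items()}
    # 3. Message passing across node types (the HGNN role).
    x = hgnn(x)
    # 4. Seed nodes assumed to be the first batch_size rows of the target table;
    #    the head maps each to out_dim values -> shape [batch_size][out_dim].
    return [head(h) for h in x[target_table][:batch_size]]


out = rdl_forward_sketch(
    tables={"users": [[1.0], [2.0], [3.0]], "orders": [[4.0]]},
    encode_table=lambda name, rows: [[2 * v for v in r] for r in rows],
    temporal_encode=None,
    hgnn=lambda x: x,            # identity stand-in for the HGNN
    head=lambda h: [sum(h)],     # out_dim = 1
    target_table="users",
    batch_size=2,
)
print(out)  # [[2.0], [4.0]]
```

Passing a non-None temporal_encode inserts step 2 without changing the rest of the pipeline, which is the role the use_temporal_encoder flag plays in the real model.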