rllm.nn.models.RelGNNModel

class rllm.nn.models.RelGNNModel(data: HeteroGraphData, col_stats_dict: Dict[str, Dict[ColType, List[Dict[StatType, Any]]]], atomic_routes_edge_types: List[Tuple[str, str, str]], hidden_dim: int, out_dim: int, tnn_hidden_dim: int = 128, tnn_num_layers: int = 4, relgnn_aggr: str = 'mean', relgnn_num_layers: int = 2, relgnn_num_heads: int = 1, relgnn_simplified_MP: bool = True, use_temporal_encoder: bool = True, reg_task: bool = False)

Bases: Module

The relational table learning model from the paper “RelGNN: Composite Message Passing for Relational Deep Learning”, with RelGNN as the HGNN backbone. Following the original paper, it uses TableResNet as the TNN component and RelGNN as the HGNN component, and adds an optional temporal encoding module.
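The stage order described above (table encoder, optional temporal encoder, message passing, output head) can be illustrated with a toy, pure-PyTorch sketch. It is not the rLLM implementation: `ToyRelGNNPipeline`, the MLP standing in for TableResNet, the linear time encoding of a per-node `rel_time`, the single node type, the mean-aggregation layer standing in for RelGNN, and the convention that seed nodes come first are all hypothetical simplifications.

```python
import torch
from torch import nn


class ToyRelGNNPipeline(nn.Module):
    """Hypothetical stand-in showing stage order only; not rllm's RelGNNModel."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int,
                 num_layers: int = 2, use_temporal_encoder: bool = True):
        super().__init__()
        # Stand-in for the TNN (TableResNet in RelGNNModel): a small MLP.
        self.table_encoder = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)
        )
        # Stand-in temporal encoder: maps a scalar relative time to hidden_dim.
        self.time_encoder = nn.Linear(1, hidden_dim) if use_temporal_encoder else None
        # Stand-in for RelGNN layers: mean aggregation followed by a linear update.
        self.layers = nn.ModuleList(
            nn.Linear(2 * hidden_dim, hidden_dim) for _ in range(num_layers)
        )
        self.head = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, rel_time=None, num_seeds=None):
        # x: [num_nodes, in_dim]; edge_index: [2, num_edges] as (src, dst) rows.
        h = self.table_encoder(x)
        if self.time_encoder is not None and rel_time is not None:
            h = h + self.time_encoder(rel_time.unsqueeze(-1))
        src, dst = edge_index
        for layer in self.layers:
            # Sum incoming neighbor embeddings per destination, then divide by degree.
            agg = torch.zeros_like(h).index_add_(0, dst, h[src])
            deg = torch.zeros(h.size(0), dtype=h.dtype).index_add_(
                0, dst, torch.ones_like(dst, dtype=h.dtype)
            )
            agg = agg / deg.clamp(min=1).unsqueeze(-1)
            h = torch.relu(layer(torch.cat([h, agg], dim=-1)))
        # Assumption: the first `num_seeds` rows are the seed nodes to predict.
        if num_seeds is not None:
            h = h[:num_seeds]
        return self.head(h)


torch.manual_seed(0)
model = ToyRelGNNPipeline(in_dim=8, hidden_dim=16, out_dim=1)
x = torch.randn(5, 8)
edge_index = torch.tensor([[1, 2, 3, 4], [0, 0, 1, 2]])
rel_time = torch.rand(5)
out = model(x, edge_index, rel_time, num_seeds=2)
print(out.shape)  # torch.Size([2, 1])
```

The real model operates on a heterogeneous graph with one encoder per table and RelGNN's composite message passing over atomic routes; the toy version collapses all of that into one node type purely to make the data flow visible.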

Parameters:
  • data (HeteroGraphData) – The heterogeneous graph data.

  • col_stats_dict (Dict[str, Dict[ColType, List[Dict[StatType, Any]]]]) – The column statistics dictionary for each table.

  • atomic_routes_edge_types (List[Tuple[str, str, str]]) – The list of atomic message passing routes produced by get_atomic_routes().

  • hidden_dim (int) – The hidden dimension.

  • out_dim (int) – The output dimension.

  • tnn_hidden_dim (int) – The hidden dimension of the TNN. (default: 128)

  • tnn_num_layers (int) – The number of TNN layers. (default: 4)

  • relgnn_aggr (str) – The aggregation method used by RelGNN. (default: 'mean')

  • relgnn_num_layers (int) – The number of RelGNN layers. (default: 2)

  • relgnn_num_heads (int) – The number of attention heads in RelGNN. (default: 1)

  • relgnn_simplified_MP (bool) – Whether to use simplified message passing in RelGNN. (default: True)

  • use_temporal_encoder (bool) – Whether to use the temporal encoder. (default: True)

  • reg_task (bool) – If True, use a regression output head (with GELU activation) instead of a classification head. (default: False)

Example

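>>> from rllm.nn.models import RelGNNModel
>>> from rllm.utils import get_atomic_routes
>>> # `hdata` (HeteroGraphData) and `col_stats_dict` are assumed to be
>>> # built beforehand from the relational dataset.
>>> routes = get_atomic_routes(hdata.edge_types)
>>> model = RelGNNModel(
...     data=hdata,
...     col_stats_dict=col_stats_dict,
...     atomic_routes_edge_types=routes,
...     hidden_dim=128,
...     out_dim=1,  # one output value per seed node
... )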

forward(batch: HeteroGraphData, target_table: str) → Tensor

Run table encoding, optional temporal encoding, RelGNN propagation, and the output head.

Parameters:
  • batch (HeteroGraphData) – Batched heterogeneous relational graph.

  • target_table (str) – The node type (table) whose seed nodes are predicted.

Returns:

Output predictions for seed nodes in the target table, of shape [batch_size, out_dim].

Return type:

Tensor
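Because forward() returns a [batch_size, out_dim] tensor for the seed nodes, a training step mainly needs a loss that matches the head. The sketch below uses a hypothetical helper `compute_loss` and random tensors in place of real model output; pairing `reg_task=False, out_dim=1` with a binary cross-entropy on logits and `reg_task=True, out_dim=1` with an L1 loss is an assumption about typical usage, not something this page specifies.

```python
import torch
import torch.nn.functional as F


def compute_loss(out: torch.Tensor, y: torch.Tensor, reg_task: bool) -> torch.Tensor:
    """Hypothetical helper. out: [batch_size, 1] from forward(); y: [batch_size]."""
    out = out.squeeze(-1)  # [batch_size, 1] -> [batch_size]
    if reg_task:
        # Regression head: compare predictions directly with the targets.
        return F.l1_loss(out, y.float())
    # Binary classification head: treat the single output as a logit.
    return F.binary_cross_entropy_with_logits(out, y.float())


torch.manual_seed(0)
out = torch.randn(4, 1, requires_grad=True)  # stand-in for model(batch, target_table)
labels = torch.tensor([0, 1, 1, 0])  # hypothetical binary labels
loss = compute_loss(out, labels, reg_task=False)
loss.backward()
print(out.grad.shape)  # torch.Size([4, 1])
```

For a multi-class task one would presumably set out_dim to the number of classes and use F.cross_entropy on the unsqueezed output instead.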

reset_parameters()

Reset all learnable parameters of the module.
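One common use of reset_parameters() is re-initializing a single model instance between repeated runs instead of constructing a new one. The sketch uses a small hypothetical `TinyModel` that follows the same convention by delegating to each child layer's own `reset_parameters()`; with rLLM one would simply call `model.reset_parameters()`. Seeding PyTorch's global generator before each reset makes the re-initialization reproducible.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical module; mirrors the reset_parameters() convention only."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.head = nn.Linear(16, 1)

    def reset_parameters(self):
        # Re-initialize every child layer that knows how to reset itself.
        for module in self.children():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()

    def forward(self, x):
        return self.head(torch.relu(self.encoder(x)))


model = TinyModel()
snapshots = []
for seed in (0, 0, 1):
    torch.manual_seed(seed)
    model.reset_parameters()  # fresh initialization for this run
    snapshots.append(model.head.weight.detach().clone())
    # ... train and evaluate this run ...
```

Runs started with the same seed begin from identical weights, while a different seed gives a different initialization.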