rllm.nn.models.RelGNN

class rllm.nn.models.RelGNN(node_types: List[str], atomic_routes_edge_types: List[Tuple], hidden_dim: int, aggr: str = 'sum', num_layers: int = 2, num_heads: int = 1, simplified_MP=True)

Bases: Module

RelGNN is a GNN framework designed to exploit the structural characteristics of graphs built from relational databases. It was proposed in the paper “RelGNN: Composite Message Passing for Relational Deep Learning”.

Parameters:
  • node_types (List[str]) – The list of node types in the graph.

  • atomic_routes_edge_types (List[Tuple]) – The list of atomic message passing routes produced by get_atomic_routes().

  • hidden_dim (int) – The dimensionality of the hidden node embeddings.

  • aggr (str) – The aggregation method across parallel routes. (default: 'sum')

  • num_layers (int) – The number of message passing layers. (default: 2)

  • num_heads (int) – The number of attention heads. (default: 1)

  • simplified_MP (bool) – If True, skips routes whose edge index is absent in the batch. (default: True)
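
The route skipping described for simplified_MP can be pictured with a minimal sketch. It assumes, purely for illustration, that each route is identified by a single (source, relation, destination) key; the actual route format produced by get_atomic_routes() may differ, and routes_in_batch is a hypothetical helper, not part of rllm.

```python
from typing import Dict, List, Tuple

# Assumption: a route is identified by one (source, relation, destination) key.
EdgeType = Tuple[str, str, str]


def routes_in_batch(
    routes: List[EdgeType],
    edge_index_dict: Dict[EdgeType, list],
) -> List[EdgeType]:
    """Keep only the routes whose key has an entry in this batch."""
    return [route for route in routes if route in edge_index_dict]


# Hypothetical routes and a batch that happens to contain only one of them.
routes = [
    ("orders", "placed_by", "users"),
    ("reviews", "written_by", "users"),
]
batch_edges = {("orders", "placed_by", "users"): [(0, 0), (1, 0)]}

active = routes_in_batch(routes, batch_edges)
print(active)
```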

Example

>>> from rllm.nn.models import RelGNN
>>> from rllm.utils import get_atomic_routes
>>> routes = get_atomic_routes(hdata.edge_types)
>>> model = RelGNN(
...     node_types=hdata.node_types,
...     atomic_routes_edge_types=routes,
...     hidden_dim=128,
... )
forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Tensor]) → Dict[str, Tensor]

Apply stacked relational message passing layers.

Parameters:
  • x_dict (Dict[str, Tensor]) – Input node features by node type.

  • edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – Edge indices by edge type.

Returns:

Updated node embeddings.

Return type:

Dict[str, Tensor]
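
To make "stacked" concrete, the sketch below applies a layer repeatedly, feeding each layer's output dictionary into the next. It is a dependency-free illustration, not the model's implementation: plain Python lists stand in for tensors, edges are written as (source index, destination index) pairs, and toy_layer is a hypothetical stand-in that simply adds neighbour features, whereas the real layers are learned.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]
NodeFeatures = Dict[str, List[List[float]]]
EdgeIndices = Dict[EdgeType, List[Tuple[int, int]]]


def toy_layer(x_dict: NodeFeatures, edge_index_dict: EdgeIndices) -> NodeFeatures:
    """Hypothetical layer: each destination node adds its source neighbours' features."""
    # Start from a copy so the layer reads only the previous layer's values.
    out = {ntype: [row[:] for row in rows] for ntype, rows in x_dict.items()}
    for (src, _rel, dst), edges in edge_index_dict.items():
        for s, d in edges:
            out[dst][d] = [a + b for a, b in zip(out[dst][d], x_dict[src][s])]
    return out


def stacked_forward(
    x_dict: NodeFeatures, edge_index_dict: EdgeIndices, num_layers: int = 2
) -> NodeFeatures:
    """Apply the layer num_layers times, threading the node features through."""
    for _ in range(num_layers):
        x_dict = toy_layer(x_dict, edge_index_dict)
    return x_dict


# Hypothetical two-table graph: two users and one order.
x = {"users": [[1.0], [2.0]], "orders": [[10.0]]}
edges = {
    ("orders", "placed_by", "users"): [(0, 0)],
    ("users", "places", "orders"): [(0, 0), (1, 0)],
}
result = stacked_forward(x, edges, num_layers=2)
print(result)
```

The output keeps the same node-type keys as the input, which matches the Dict[str, Tensor] return type documented above.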

heteroconv_forward(conv_dict: ModuleDict, x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Tensor]) → Dict[str, Tensor]

Apply one layer of heterogeneous convolution across all atomic routes and aggregate the results per destination node type.

Parameters:
  • conv_dict (ModuleDict) – Convolution modules keyed by route string.

  • x_dict (Dict[str, Tensor]) – Input node features by node type.

  • edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – Edge indices by edge type.

Returns:

Updated node embeddings by node type.

Return type:

Dict[str, Tensor]
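
The per-destination aggregation step can be sketched as follows. This is an illustration under stated assumptions rather than the library's code: route outputs are keyed by a (source, relation, destination) triple, plain nested lists stand in for tensors, and aggregate_by_destination is a hypothetical helper. Only 'sum' is confirmed by this page as a value of aggr; 'mean' is included as an assumed alternative.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]
Matrix = List[List[float]]


def aggregate_by_destination(
    route_outputs: Dict[EdgeType, Matrix], aggr: str = "sum"
) -> Dict[str, Matrix]:
    """Combine the outputs of all routes that end at the same node type."""
    if aggr not in ("sum", "mean"):
        raise ValueError(f"unsupported aggr in this sketch: {aggr!r}")

    # Group the per-route outputs by their destination node type.
    grouped: Dict[str, List[Matrix]] = defaultdict(list)
    for (_src, _rel, dst), out in route_outputs.items():
        grouped[dst].append(out)

    result: Dict[str, Matrix] = {}
    for dst, outs in grouped.items():
        # zip(*outs) aligns rows by node; zip(*rows) aligns values by feature.
        combined = [[sum(vals) for vals in zip(*rows)] for rows in zip(*outs)]
        if aggr == "mean":
            combined = [[v / len(outs) for v in row] for row in combined]
        result[dst] = combined
    return result


# Hypothetical outputs: two routes end at "users", one ends at "orders".
route_outputs = {
    ("orders", "placed_by", "users"): [[1.0, 2.0], [3.0, 4.0]],
    ("reviews", "written_by", "users"): [[10.0, 20.0], [30.0, 40.0]],
    ("users", "places", "orders"): [[5.0, 6.0]],
}
summed = aggregate_by_destination(route_outputs, aggr="sum")
averaged = aggregate_by_destination(route_outputs, aggr="mean")
print(summed)
```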

reset_parameters()

Reset all learnable parameters of the module.