rllm.nn.models.RelGNN
- class rllm.nn.models.RelGNN(node_types: List[str], atomic_routes_edge_types: List[Tuple], hidden_dim: int, aggr: str = 'sum', num_layers: int = 2, num_heads: int = 1, simplified_MP=True)
Bases: Module
The RelGNN model is a GNN framework designed to exploit the structural characteristics of graphs built from relational databases, as introduced in the paper “RelGNN: Composite Message Passing for Relational Deep Learning”.
- Parameters:
node_types (List[str]) – The list of node types in the graph.
atomic_routes_edge_types (List[Tuple]) – The list of atomic message passing routes produced by get_atomic_routes().
hidden_dim (int) – The size of the hidden node embeddings.
aggr (str) – The aggregation method used to combine the outputs of parallel routes that share a destination node type. (default: 'sum')
num_layers (int) – The number of message passing layers. (default: 2)
num_heads (int) – The number of attention heads. (default: 1)
simplified_MP (bool) – If True, skips routes whose edge index is absent in the batch. (default: True)
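The skipping rule for simplified_MP amounts to filtering the configured routes against the keys present in the batch. The sketch below is plain Python and assumes routes are (src, relation, dst) tuples used directly as edge_index_dict keys; active_routes is a hypothetical helper, not an rllm function, and the simplified_mp=False branch (keep every route) is an assumption, since this page does not say how absent routes are handled in that case.

```python
from typing import Dict, List, Tuple

Route = Tuple[str, str, str]  # assumed (src, relation, dst) layout


def active_routes(
    routes: List[Route],
    edge_index_dict: Dict[Route, object],
    simplified_mp: bool = True,
) -> List[Route]:
    """Hypothetical helper: select the routes that would run message passing."""
    if not simplified_mp:
        # Assumption: without simplification, every configured route is kept.
        return list(routes)
    # With simplification, routes whose edge index is absent in the batch are skipped.
    return [r for r in routes if r in edge_index_dict]


routes = [
    ("user", "rates", "movie"),
    ("movie", "rev_rates", "user"),
    ("user", "follows", "user"),
]
# This batch happens to contain no ("user", "follows", "user") edges.
batch = {
    ("user", "rates", "movie"): [[0, 1], [2, 0]],
    ("movie", "rev_rates", "user"): [[2, 0], [0, 1]],
}
print(active_routes(routes, batch))
# [('user', 'rates', 'movie'), ('movie', 'rev_rates', 'user')]
```

Skipping rather than running a route on an empty edge set avoids launching work that cannot produce any messages for that batch.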
Example
>>> from rllm.nn.models import RelGNN
>>> from rllm.utils import get_atomic_routes
>>> # hdata: a heterogeneous graph exposing node_types and edge_types
>>> routes = get_atomic_routes(hdata.edge_types)
>>> model = RelGNN(
...     node_types=hdata.node_types,
...     atomic_routes_edge_types=routes,
...     hidden_dim=128,
... )
- forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Tensor]) → Dict[str, Tensor]
Apply stacked relational message passing layers.
- Parameters:
x_dict (Dict[str, Tensor]) – Input node features by node type.
edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – Edge indices by route.
- Returns:
Updated node embeddings by node type.
- Return type:
Dict[str, Tensor]
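To make the layer stacking in forward concrete, here is a library-free toy in plain Python. It is not rllm's implementation: toy_layer and toy_forward are hypothetical names, each node carries a single float instead of a tensor row, edge indices use a PyG-style [source indices, destination indices] layout (an assumption), and a per-route mean stands in for the learned per-route convolution. What it does show is the control flow: every layer reads the whole x_dict produced by the previous one, so information travels one more hop per layer.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]
XDict = Dict[str, List[float]]  # one scalar feature per node, for readability


def toy_layer(x_dict: XDict, edge_index_dict: Dict[EdgeType, List[List[int]]]) -> XDict:
    """Hypothetical stand-in for one layer: mean of source features per route,
    then a sum over all routes that end at the same node type."""
    out = {t: [0.0] * len(v) for t, v in x_dict.items()}
    received = set()
    for (src, _rel, dst), (src_idx, dst_idx) in edge_index_dict.items():
        sums = [0.0] * len(x_dict[dst])
        counts = [0] * len(x_dict[dst])
        for s, d in zip(src_idx, dst_idx):
            sums[d] += x_dict[src][s]
            counts[d] += 1
        for i, c in enumerate(counts):
            if c:
                out[dst][i] += sums[i] / c
        received.add(dst)
    # Assumption: node types that receive no messages keep their input features.
    for t in x_dict:
        if t not in received:
            out[t] = list(x_dict[t])
    return out


def toy_forward(x_dict: XDict, edge_index_dict, num_layers: int = 2) -> XDict:
    """Stack num_layers layers; each layer consumes the previous layer's output."""
    for _ in range(num_layers):
        x_dict = toy_layer(x_dict, edge_index_dict)
    return x_dict


x = {"user": [1.0, 3.0], "movie": [10.0]}
edges = {
    ("user", "rates", "movie"): [[0, 1], [0, 0]],      # [source indices, destination indices]
    ("movie", "rev_rates", "user"): [[0, 0], [0, 1]],
}
print(toy_forward(x, edges, num_layers=1))  # {'user': [10.0, 10.0], 'movie': [2.0]}
print(toy_forward(x, edges, num_layers=2))  # {'user': [2.0, 2.0], 'movie': [10.0]}
```

After one layer the movie holds the mean of its raters' features; after two, each user sees that value again, which is the two-hop user → movie → user path.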
- heteroconv_forward(conv_dict: ModuleDict, x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str, str], Tensor]) → Dict[str, Tensor]
Apply one layer of heterogeneous convolution across all atomic routes and aggregate the results per destination node type.
- Parameters:
conv_dict (ModuleDict) – Convolution modules keyed by route string.
x_dict (Dict[str, Tensor]) – Input node features by node type.
edge_index_dict (Dict[Tuple[str, str, str], Tensor]) – Edge indices by edge type.
- Returns:
Updated node embeddings by node type.
- Return type:
Dict[str, Tensor]
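The per-destination aggregation in heteroconv_forward can be sketched as follows, assuming each route's convolution has already produced one output value per destination node. aggregate_by_dst is a hypothetical helper written for illustration; the 'sum', 'mean' and 'max' options are assumptions modeled on common choices (for example in PyG's HeteroConv), since this page only documents the default 'sum'.

```python
from typing import Dict, List, Tuple

Route = Tuple[str, str, str]  # assumed (src, relation, dst) layout


def aggregate_by_dst(route_outputs: Dict[Route, List[float]], aggr: str = "sum") -> Dict[str, List[float]]:
    """Hypothetical sketch: combine per-route outputs that share a destination type."""
    grouped: Dict[str, List[List[float]]] = {}
    for (_src, _rel, dst), out in route_outputs.items():
        grouped.setdefault(dst, []).append(out)
    result: Dict[str, List[float]] = {}
    for dst, outs in grouped.items():
        columns = list(zip(*outs))  # one tuple of route outputs per destination node
        if aggr == "sum":
            result[dst] = [sum(c) for c in columns]
        elif aggr == "mean":
            result[dst] = [sum(c) / len(c) for c in columns]
        elif aggr == "max":
            result[dst] = [max(c) for c in columns]
        else:
            raise ValueError(f"unsupported aggr: {aggr!r}")
    return result


outputs = {
    ("user", "rates", "movie"): [1.0, 4.0],
    ("tag", "labels", "movie"): [3.0, 2.0],
    ("movie", "rev_rates", "user"): [5.0],
}
print(aggregate_by_dst(outputs, "sum"))   # {'movie': [4.0, 6.0], 'user': [5.0]}
print(aggregate_by_dst(outputs, "mean"))  # {'movie': [2.0, 3.0], 'user': [5.0]}
```

Routes that end at the same node type (here both routes into movie) are merged, while a type reached by a single route passes through unchanged for every choice of aggr.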