rllm.nn.conv.graph_conv.HGTConv
- class rllm.nn.conv.graph_conv.HGTConv(in_dim: int | Dict[str, int], out_dim: int, metadata: Tuple[List[str], List[Tuple[str, str]]], num_heads: int = 1, dropout: float = 0.0, *, aggr: str = 'sum', **kwargs)
Bases: MessagePassing
The Heterogeneous Graph Transformer (HGT) layer, implemented with message passing, as introduced in the “Heterogeneous Graph Transformer” paper.
This layer models type-specific node interactions through relation-aware attention and aggregates the resulting messages for each destination node type.
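To make the relation-aware attention step concrete, here is a minimal pure-Python sketch of single-head, HGT-style attention over one relation (source type to destination type). It is not rLLM's implementation: the helper `relation_attention` and its arguments are hypothetical, the relation-specific attention and message matrices from the HGT paper are assumed to be folded into `W_k` and `W_v`, and the relation prior, multi-head split, and output update are omitted. Each destination node applies a softmax over its own incoming edges only, optionally drops normalized coefficients, and sums the weighted value vectors.

```python
import math
import random
from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]  # row-major: one inner list per output dimension


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def matvec(W: Matrix, x: Sequence[float]) -> Vector:
    # Computes W @ x for a row-major matrix W.
    return [dot(row, x) for row in W]


def relation_attention(
    x_src: List[Vector],
    x_dst: List[Vector],
    edges: List[Tuple[int, int]],
    W_q: Matrix,
    W_k: Matrix,
    W_v: Matrix,
    dropout: float = 0.0,
    rng: Optional[random.Random] = None,
) -> Dict[int, Vector]:
    """Hypothetical single-head attention over one relation (src type -> dst type).

    edges holds (src_index, dst_index) pairs. Returns a dict mapping each
    destination index that has at least one incoming edge to its summed message.
    """
    if not 0.0 <= dropout < 1.0:
        raise ValueError("dropout must be in [0, 1)")
    rng = rng if rng is not None else random.Random()
    d = len(W_q)  # attention dimension used for scaling

    # Group incoming edges by destination node.
    incoming: Dict[int, List[int]] = defaultdict(list)
    for s, t in edges:
        incoming[t].append(s)

    out: Dict[int, Vector] = {}
    for t, sources in incoming.items():
        q = matvec(W_q, x_dst[t])
        # Scaled dot-product score for every edge pointing into node t.
        scores = [dot(q, matvec(W_k, x_src[s])) / math.sqrt(d) for s in sources]
        # Softmax over node t's incoming edges only (max-shifted for stability).
        m = max(scores)
        exps = [math.exp(sc - m) for sc in scores]
        z = sum(exps)
        alpha = [e / z for e in exps]
        # Inverted dropout on the normalized coefficients; a caller would
        # pass dropout > 0 only while training.
        if dropout > 0.0:
            alpha = [0.0 if rng.random() < dropout else a / (1.0 - dropout) for a in alpha]
        # Sum aggregation of the attention-weighted value vectors.
        msg = [0.0] * len(W_v)
        for a, s in zip(alpha, sources):
            v = matvec(W_v, x_src[s])
            msg = [mi + a * vi for mi, vi in zip(msg, v)]
        out[t] = msg
    return out


I2 = [[1.0, 0.0], [0.0, 1.0]]
print(relation_attention([[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0]], [(0, 0), (1, 0)], I2, I2, I2))
# {0: [0.5, 0.5]}
```

With identity projections, a destination node whose two neighbors receive equal scores gets the average of their value vectors. In a heterogeneous layer, a routine like this would run once per edge type with type-specific matrices, and the per-relation results would then be combined for each destination type.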
- Parameters:
in_dim (Union[int, Dict[str, int]]) – Size of each input sample. Pass an int to use the same size for every node type, or a dict mapping each node type to its own input size.
out_dim (int) – Size of each output sample of every node type.
metadata (Tuple[List[str], List[Tuple[str, str]]]) – The metadata of the heterogeneous graph, i.e. its node types given as a list of strings and its edge types given as a list of (source node type, destination node type) string pairs.
num_heads (int, optional) – Number of attention heads (default: 1).
dropout (float, optional) – Dropout probability applied to the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training (default: 0.0).
aggr (str, optional) – The aggregation scheme to use (default: ‘sum’).
**kwargs (optional) – Additional arguments passed to MessagePassing.
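The `metadata` argument only needs the node types and the (source, destination) edge types. The sketch below shows one way to assemble and sanity-check it before constructing the layer; `build_metadata` is a hypothetical helper, not part of rLLM, and it assumes the pair form given in the signature and the example rather than (src, relation, dst) triplets.

```python
from typing import List, Sequence, Tuple

EdgeType = Tuple[str, str]  # (source node type, destination node type)
Metadata = Tuple[List[str], List[EdgeType]]


def build_metadata(node_types: Sequence[str], edge_types: Sequence[Sequence[str]]) -> Metadata:
    """Hypothetical helper: assemble and validate an HGTConv-style metadata tuple."""
    nodes = list(node_types)
    if len(set(nodes)) != len(nodes):
        raise ValueError("node types must be unique")
    known = set(nodes)
    edges: List[EdgeType] = []
    for edge_type in edge_types:
        # Only (src, dst) pairs are accepted, matching the documented signature.
        if len(edge_type) != 2:
            raise ValueError(f"expected a (src, dst) pair, got {edge_type!r}")
        src, dst = edge_type
        if src not in known or dst not in known:
            raise ValueError(f"edge type {edge_type!r} uses an unknown node type")
        edges.append((src, dst))
    return nodes, edges


metadata = build_metadata(["a", "b"], [("a", "b")])
print(metadata)  # (['a', 'b'], [('a', 'b')])
```

Catching a typo in a node type here gives a clearer error than a missing key deep inside message passing.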
- forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str], Tensor])
Perform heterogeneous transformer message passing by relation type.
- Parameters:
x_dict (Dict[str, Tensor]) – Mapping from node type to node features.
edge_index_dict (Dict[Tuple[str, str], Union[Tensor, SparseTensor]]) – Mapping from each edge type (source node type, destination node type) to its graph connectivity.
- Returns:
Output embeddings for each node type, keyed by node type.
- Return type:
Dict[str, Tensor]
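Each value in `edge_index_dict` is typically a 2 x num_edges index, where row 0 indexes the source node type's features and row 1 the destination type's. This follows the common PyG-style convention, which the example below is consistent with, but treat it as an assumption. `to_edge_index` is a hypothetical helper that builds that layout from (src, dst) pairs and checks the indices against the node counts; wrap its result with `torch.tensor(..., dtype=torch.long)` before calling `forward`.

```python
from typing import List, Sequence, Tuple


def to_edge_index(pairs: Sequence[Tuple[int, int]], num_src: int, num_dst: int) -> List[List[int]]:
    """Hypothetical helper: turn (src, dst) index pairs into a 2 x num_edges list.

    Row 0 holds indices into the source type's features, row 1 holds indices
    into the destination type's features.
    """
    rows: List[List[int]] = [[], []]
    for s, t in pairs:
        if not (0 <= s < num_src and 0 <= t < num_dst):
            raise IndexError(f"edge ({s}, {t}) is out of range for {num_src} src / {num_dst} dst nodes")
        rows[0].append(s)
        rows[1].append(t)
    return rows


# Reproduces the connectivity used for ('a', 'b') in the example: 2 'a' nodes, 3 'b' nodes.
print(to_edge_index([(0, 1), (1, 2)], num_src=2, num_dst=3))  # [[0, 1], [1, 2]]
```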
Example
>>> import torch
>>> from rllm.nn.conv.graph_conv import HGTConv
>>> metadata = (['a', 'b'], [('a', 'b')])
>>> conv = HGTConv(16, 8, metadata, num_heads=1)
>>> x_dict = {'a': torch.randn(2, 16), 'b': torch.randn(3, 16)}
>>> edge_index_dict = {('a', 'b'): torch.tensor([[0, 1], [1, 2]])}
>>> out = conv(x_dict, edge_index_dict)
>>> out['b'].shape
torch.Size([3, 8])