rllm.nn.conv.graph_conv.HGTConv

class rllm.nn.conv.graph_conv.HGTConv(in_dim: int | Dict[str, int], out_dim: int, metadata: Tuple[List[str], List[Tuple[str, str]]], num_heads: int = 1, dropout: float = 0.0, *, aggr: str = 'sum', **kwargs)

Bases: MessagePassing

The Heterogeneous Graph Transformer (HGT) layer, implemented via message passing, as introduced in the “Heterogeneous Graph Transformer” paper.

This layer models type-specific node interactions through relation-aware attention and aggregates the resulting messages for each destination node type.
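For intuition, the relation-aware attention step for a single destination node can be sketched in plain Python. This is a simplified single-head sketch loosely following the HGT paper, not rllm's implementation: the helper name `relation_attention`, the relation-specific matrix `w_att`, and the scalar prior `mu` are illustrative assumptions. The softmax here normalizes only over the neighbors passed in (one relation), whereas the paper normalizes over all neighbors of the target node.

```python
import math
from typing import List, Tuple

Vector = List[float]


def relation_attention(
    q_t: Vector,
    keys: List[Vector],
    values: List[Vector],
    w_att: List[Vector],
    mu: float = 1.0,
) -> Tuple[List[float], Vector]:
    """Hypothetical single-head attention of one destination node over its
    source neighbors under one relation (simplified HGT-style sketch).
    Assumes at least one neighbor is given."""
    d = len(q_t)
    scores = []
    for k in keys:
        # Relation-aware key: k @ w_att, where w_att is a d x d list of rows.
        kw = [sum(k[i] * w_att[i][j] for i in range(d)) for j in range(d)]
        # Scaled dot product with the destination query, weighted by the prior mu.
        scores.append(sum(a * b for a, b in zip(kw, q_t)) * mu / math.sqrt(d))
    # Softmax over the given neighbors; subtracting the max keeps exp() stable.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    alphas = [e / total for e in exps]
    # Attention-weighted sum of the neighbor messages.
    out = [sum(a * v[j] for a, v in zip(alphas, values)) for j in range(len(values[0]))]
    return alphas, out


identity = [[1.0, 0.0], [0.0, 1.0]]
alphas, out = relation_attention(
    q_t=[1.0, 0.0],
    keys=[[1.0, 0.0], [0.0, 1.0]],
    values=[[1.0, 0.0], [0.0, 1.0]],
    w_att=identity,
)
```

With the identity matrix and `mu = 1.0`, the neighbor whose key aligns with the query receives the larger attention weight, and the output is the corresponding weighted mix of the two messages.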

Parameters:
  • in_dim (Union[int, Dict[str, int]]) – Size of each input sample. Pass a single int if all node types share the same input size, or a dict mapping each node type to its input size.

  • out_dim (int) – Size of each output sample of every node type.

  • metadata (Tuple[List[str], List[Tuple[str, str]]]) – The metadata of the heterogeneous graph, i.e. its node types, given as a list of strings, and its edge types, given as a list of (source, destination) string pairs.

  • num_heads (int, optional) – Number of attention heads (default: 1).

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training (default: 0.0).

  • aggr (str, optional) – The aggregation scheme to use (default: ‘sum’).

  • **kwargs (optional) – Additional keyword arguments passed to MessagePassing.
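How `in_dim` and `metadata` fit together can be illustrated with a small helper. This is a hypothetical sketch, not part of rllm: `normalize_in_dim` is an illustrative name, and the assumptions that an int `in_dim` is shared by every node type and that every edge type must reference declared node types are inferred from the type hints and the example in this section.

```python
from typing import Dict, List, Tuple, Union

Metadata = Tuple[List[str], List[Tuple[str, str]]]


def normalize_in_dim(in_dim: Union[int, Dict[str, int]], metadata: Metadata) -> Dict[str, int]:
    """Hypothetical helper: expand `in_dim` to one input size per node type."""
    node_types, edge_types = metadata
    # Each edge type is a (source, destination) pair of declared node types.
    for src, dst in edge_types:
        if src not in node_types or dst not in node_types:
            raise ValueError(f"edge type {(src, dst)} refers to an unknown node type")
    if isinstance(in_dim, int):
        # A single int is shared by every node type.
        return {node_type: in_dim for node_type in node_types}
    missing = [nt for nt in node_types if nt not in in_dim]
    if missing:
        raise ValueError(f"no input size given for node types {missing}")
    return {nt: in_dim[nt] for nt in node_types}


metadata = (["a", "b"], [("a", "b")])
print(normalize_in_dim(16, metadata))  # {'a': 16, 'b': 16}
print(normalize_in_dim({"a": 16, "b": 32}, metadata))  # {'a': 16, 'b': 32}
```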

forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str], Tensor])

Perform heterogeneous graph transformer message passing over each relation type.

Parameters:
  • x_dict (Dict[str, Tensor]) – Mapping from node type to node features.

  • edge_index_dict (Dict[Tuple[str, str], Union[Tensor, SparseTensor]]) – Mapping from (source, destination) node-type pairs to the connectivity of that relation.

Returns:

Output embeddings per node type.

Return type:

Dict[str, Tensor]
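The key convention for `edge_index_dict` can be illustrated with a hypothetical validation helper. Judging from the example in this section, row 0 of each `edge_index` holds source-node indices and row 1 holds destination-node indices. The sketch checks exactly that, using nested Python lists in place of tensors; `check_edge_index_dict` and its `num_nodes` argument are illustrative assumptions, not rllm API.

```python
from typing import Dict, List, Set, Tuple

EdgeType = Tuple[str, str]


def check_edge_index_dict(
    num_nodes: Dict[str, int],
    edge_index_dict: Dict[EdgeType, List[List[int]]],
) -> Set[str]:
    """Hypothetical check: each edge_index is 2 x E, with row 0 indexing the
    source node type and row 1 indexing the destination node type.
    Returns the node types that receive at least one message."""
    receivers = set()
    for (src, dst), edge_index in edge_index_dict.items():
        if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
            raise ValueError(f"{(src, dst)}: edge_index must have shape [2, E]")
        if any(not 0 <= i < num_nodes[src] for i in edge_index[0]):
            raise ValueError(f"{(src, dst)}: source index out of range")
        if any(not 0 <= j < num_nodes[dst] for j in edge_index[1]):
            raise ValueError(f"{(src, dst)}: destination index out of range")
        if edge_index[1]:
            receivers.add(dst)
    return receivers


# Mirrors the shapes in the example: 2 nodes of type 'a', 3 of type 'b'.
print(check_edge_index_dict({"a": 2, "b": 3}, {("a", "b"): [[0, 1], [1, 2]]}))  # {'b'}
```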

Example

>>> import torch
>>> from rllm.nn.conv.graph_conv import HGTConv
>>> metadata = (['a', 'b'], [('a', 'b')])
>>> conv = HGTConv(16, 8, metadata, num_heads=1)
>>> x_dict = {'a': torch.randn(2, 16), 'b': torch.randn(3, 16)}
>>> edge_index_dict = {('a', 'b'): torch.tensor([[0, 1], [1, 2]])}
>>> out = conv(x_dict, edge_index_dict)
>>> out['b'].shape
torch.Size([3, 8])

message_and_aggregate(edge_index, q_dst, k_src, v_src, rel, dim_size)

The message and aggregation interface to be overridden by subclasses.
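What a single-relation `message_and_aggregate` step typically does can be sketched without tensors: score each edge, normalize the scores with a softmax grouped by destination node, and scatter-add the weighted source values into `dim_size` output rows. This is a hypothetical single-head sketch of that common pattern, not rllm's implementation. The name `edge_softmax_aggregate` is illustrative, dropout and multiple heads are omitted, and the `rel` argument is left out because its role is not documented here.

```python
import math
from typing import Dict, List

Matrix = List[List[float]]


def edge_softmax_aggregate(
    edge_index: List[List[int]],
    q_dst: Matrix,
    k_src: Matrix,
    v_src: Matrix,
    dim_size: int,
) -> Matrix:
    """Hypothetical single-head message_and_aggregate for one relation."""
    src, dst = edge_index
    d = len(q_dst[0])
    # One attention score per edge: scaled dot product of query and key.
    scores = [
        sum(q * k for q, k in zip(q_dst[j], k_src[i])) / math.sqrt(d)
        for i, j in zip(src, dst)
    ]
    # Softmax grouped by destination node (max-subtracted for stability).
    group_max: Dict[int, float] = {}
    for j, s in zip(dst, scores):
        group_max[j] = max(group_max.get(j, s), s)
    exps = [math.exp(s - group_max[j]) for j, s in zip(dst, scores)]
    group_sum: Dict[int, float] = {}
    for j, e in zip(dst, exps):
        group_sum[j] = group_sum.get(j, 0.0) + e
    # Scatter-add weighted source values into dim_size destination rows;
    # destinations without incoming edges keep a zero row.
    out_dim = len(v_src[0])
    out = [[0.0] * out_dim for _ in range(dim_size)]
    for i, j, e in zip(src, dst, exps):
        alpha = e / group_sum[j]
        for c in range(out_dim):
            out[j][c] += alpha * v_src[i][c]
    return out


# Same connectivity as the example: edges a0 -> b1 and a1 -> b2, three 'b' nodes.
out = edge_softmax_aggregate(
    [[0, 1], [1, 2]],
    q_dst=[[1.0, 0.0]] * 3,
    k_src=[[1.0, 0.0], [0.0, 1.0]],
    v_src=[[1.0, 2.0], [3.0, 4.0]],
    dim_size=3,
)
print(out)  # [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]
```

Because each destination in this usage has a single incoming edge, its attention weight is 1 and it simply receives its neighbor's value, while node 0 of type 'b' receives nothing and stays zero.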