rllm.nn.conv.graph_conv.HANConv

class rllm.nn.conv.graph_conv.HANConv(in_dim: int | Dict[str, int], out_dim: int, metadata: Tuple[List[str], List[Tuple[str, str]]], num_heads: int = 1, negative_slope: float = 0.2, dropout: float = 0.0, *, aggr: str = 'sum', **kwargs)

Bases: MessagePassing

The Heterogeneous Graph Attention Network (HAN) layer, implemented with message passing, as introduced in the “Heterogeneous Graph Attention Network” paper.

For each edge type, the layer aggregates neighbor features with node-level attention. It then combines the per-edge-type results of each destination node type with semantic-level attention, so the resulting node representations reflect both the graph structure and the different relation types of the heterogeneous graph.
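For reference, the two attention levels as formulated in the HAN paper can be written as below, for an edge type $\Phi = (s, t)$ with projected features $\mathbf{h}'$, neighbor set $\mathcal{N}^{\Phi}(i)$ of a destination node $i$ of type $t$, and an activation $\sigma$. Using LeakyReLU (which `negative_slope` suggests) and these particular projections is an assumption; rLLM's exact parameterization may differ.

$$
\alpha^{\Phi}_{ij} = \operatorname{softmax}_{j \in \mathcal{N}^{\Phi}(i)}\Big(\operatorname{LeakyReLU}\big(\mathbf{a}_{\Phi}^{\top}[\mathbf{h}'_i \,\|\, \mathbf{h}'_j]\big)\Big),
\qquad
\mathbf{z}^{\Phi}_i = \sigma\Big(\sum_{j \in \mathcal{N}^{\Phi}(i)} \alpha^{\Phi}_{ij}\,\mathbf{h}'_j\Big)
$$

$$
w_{\Phi} = \frac{1}{|\mathcal{V}_t|}\sum_{i \in \mathcal{V}_t} \mathbf{q}^{\top}\tanh\big(\mathbf{W}\mathbf{z}^{\Phi}_i + \mathbf{b}\big),
\qquad
\beta_{\Phi} = \frac{\exp(w_{\Phi})}{\sum_{\Phi'} \exp(w_{\Phi'})},
\qquad
\mathbf{z}_i = \sum_{\Phi} \beta_{\Phi}\,\mathbf{z}^{\Phi}_i
$$

Here the sums over $\Phi'$ and $\Phi$ run over all edge types whose destination type is $t$, and `dropout` acts on the normalized coefficients $\alpha^{\Phi}_{ij}$.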

Parameters:
  • in_dim (int or Dict[str, int]) – Size of each input sample of every node type: either a single size shared by all node types or a dict mapping each node type to its input size.

  • out_dim (int) – Size of each output sample of every node type.

  • metadata (Tuple[List[str], List[Tuple[str, str]]]) – The metadata of the heterogeneous graph, i.e. its node types and edge types, given as a list of strings and a list of (source, destination) string pairs, respectively.

  • num_heads (int, optional) – Number of attention heads (default: 1).

  • negative_slope (float, optional) – Negative slope of the LeakyReLU activation (default: 0.2).

  • dropout (float, optional) – Dropout probability of the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training (default: 0.0).

  • aggr (str, optional) – The aggregation scheme to use; keyword-only (default: ‘sum’).

  • **kwargs (optional) – Additional arguments of MessagePassing.
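This page does not state how an int in_dim or the metadata tuple is consumed internally. The plain-Python sketch below assumes an int is broadcast to every node type and that edge types are (source, destination) pairs, as in the signature and the example. The helpers expand_in_dim and check_metadata are hypothetical illustrations of the expected argument shapes, not part of rLLM.

```python
from typing import Dict, List, Tuple, Union

EdgeType = Tuple[str, str]  # assumed: (source node type, destination node type)


def expand_in_dim(in_dim: Union[int, Dict[str, int]],
                  node_types: List[str]) -> Dict[str, int]:
    """Hypothetical helper: broadcast a single int to every node type."""
    if isinstance(in_dim, int):
        return {nt: in_dim for nt in node_types}
    missing = [nt for nt in node_types if nt not in in_dim]
    if missing:
        raise ValueError(f"in_dim has no size for node types {missing}")
    return dict(in_dim)


def check_metadata(metadata: Tuple[List[str], List[EdgeType]]) -> None:
    """Hypothetical helper: every edge type must connect declared node types."""
    node_types, edge_types = metadata
    known = set(node_types)
    for src, dst in edge_types:  # a triplet here would fail to unpack
        if src not in known or dst not in known:
            raise ValueError(f"edge type {(src, dst)} uses an undeclared node type")


metadata = (['u', 'i'], [('u', 'i')])
check_metadata(metadata)
print(expand_in_dim(16, metadata[0]))                   # {'u': 16, 'i': 16}
print(expand_in_dim({'u': 16, 'i': 32}, metadata[0]))   # {'u': 16, 'i': 32}
```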

forward(x_dict: Dict[str, Tensor], edge_index_dict: Dict[Tuple[str, str], Tensor], return_semantic_attn_weights: bool = False)

Apply HAN message passing over each edge type, then fuse the results for each node type with semantic attention.

Parameters:
  • x_dict (Dict[str, Tensor]) – Mapping from node type to node feature matrix.

  • edge_index_dict (Dict[Tuple[str, str], Union[Tensor, SparseTensor]]) – Mapping from each (source, destination) edge type pair to its graph connectivity.

  • return_semantic_attn_weights (bool, optional) – If True, also return the semantic attention weights for each node type (default: False).

Returns:

Per-node-type output embeddings; if return_semantic_attn_weights is True, a tuple of the embeddings and the per-node-type semantic attention weights.

Return type:

Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Dict[str, Tensor]]]
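The semantic weights returned with return_semantic_attn_weights=True come from a fusion step across edge types. A minimal, framework-free sketch of that step for one destination node type, following the HAN paper's formulation, is shown below. The names semantic_fusion, W, b and q are hypothetical; rLLM's internal parameterization (per-head handling, tensor shape of the returned weights) is not documented here and may differ.

```python
import math
from typing import Dict, List, Tuple

Vector = List[float]
EdgeType = Tuple[str, str]


def _tanh_proj(z: Vector, W: List[Vector], b: Vector) -> Vector:
    # tanh(W z + b) for a single node embedding z
    return [math.tanh(sum(w_k * z_k for w_k, z_k in zip(row, z)) + b_r)
            for row, b_r in zip(W, b)]


def semantic_fusion(per_type: Dict[EdgeType, List[Vector]],
                    W: List[Vector], b: Vector, q: Vector
                    ) -> Tuple[List[Vector], Dict[EdgeType, float]]:
    """Fuse per-edge-type embeddings of one destination node type.

    per_type maps each incoming edge type to a [num_nodes x dim] list of rows.
    Returns the fused embeddings and one semantic weight per edge type.
    """
    edge_types = list(per_type)
    # w_Phi: mean over nodes of q^T tanh(W z_i + b)
    scores = []
    for et in edge_types:
        rows = per_type[et]
        total = sum(sum(q_k * t_k for q_k, t_k in zip(q, _tanh_proj(z, W, b)))
                    for z in rows)
        scores.append(total / len(rows))
    # beta_Phi: softmax over edge types, shifted by the max for stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    beta = {et: e / denom for et, e in zip(edge_types, exps)}
    # z_i = sum over Phi of beta_Phi * z_i^Phi
    num_nodes = len(per_type[edge_types[0]])
    dim = len(per_type[edge_types[0]][0])
    fused = [[sum(beta[et] * per_type[et][i][d] for et in edge_types)
              for d in range(dim)]
             for i in range(num_nodes)]
    return fused, beta


per_type = {
    ('u', 'i'): [[1.0, 0.0], [0.0, 1.0]],
    ('i', 'i'): [[0.5, 0.5], [0.5, 0.5]],
}
fused, beta = semantic_fusion(per_type, W=[[1.0, 0.0], [0.0, 1.0]],
                              b=[0.0, 0.0], q=[1.0, 1.0])
print(beta)  # ('i', 'i') receives the larger weight here
```

Averaging the score over all destination nodes gives one weight per edge type rather than per node, which is what makes the weights interpretable as the importance of each relation.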

Example

>>> import torch
>>> from rllm.nn.conv.graph_conv import HANConv
>>> metadata = (['u', 'i'], [('u', 'i')])
>>> conv = HANConv(16, 8, metadata)
>>> x_dict = {'u': torch.randn(2, 16), 'i': torch.randn(3, 16)}
>>> edge_index_dict = {('u', 'i'): torch.tensor([[0, 1], [1, 2]])}
>>> out_dict = conv(x_dict, edge_index_dict)
>>> out_dict['i'].shape
torch.Size([3, 8])

message_and_aggregate(edge_index, src_x, alpha, dim_size)

The message and aggregation interface to be overridden by subclasses.
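The argument semantics of message_and_aggregate are not documented on this page. The plain-Python sketch below assumes that edge_index is a 2 × E pair of (source, destination) index rows, as in the example's edge_index_dict; that alpha holds one attention coefficient per edge (a single head); and that dim_size is the number of destination nodes. Under those assumptions, it performs the attention-weighted sum implied by the default aggr='sum'. The function name message_and_aggregate_sum is hypothetical and this is not rLLM's implementation.

```python
from typing import List, Sequence


def message_and_aggregate_sum(edge_index: Sequence[Sequence[int]],
                              src_x: List[List[float]],
                              alpha: Sequence[float],
                              dim_size: int) -> List[List[float]]:
    """Hypothetical sum aggregation: out[dst] += alpha[e] * src_x[src]."""
    src_idx, dst_idx = edge_index
    dim = len(src_x[0])
    # One zero row per destination node, so isolated nodes stay at zero
    out = [[0.0] * dim for _ in range(dim_size)]
    for e, (j, i) in enumerate(zip(src_idx, dst_idx)):
        for d in range(dim):
            out[i][d] += alpha[e] * src_x[j][d]
    return out


# Same connectivity as the doctest example: u0 -> i1, u1 -> i2
out = message_and_aggregate_sum([[0, 1], [1, 2]],
                                [[1.0, 2.0], [3.0, 4.0]],
                                [0.5, 1.0], dim_size=3)
print(out)  # [[0.0, 0.0], [0.5, 1.0], [3.0, 4.0]]
```

Fusing the message and aggregation steps like this avoids materializing a per-edge message list, which is a common reason to override such a hook.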