Source code for rllm.nn.conv.graph_conv.han_conv

from typing import Union, Tuple, List, Dict

import torch
from torch import Tensor
from torch.sparse import Tensor as SparseTensor
from torch.nn import Parameter
import torch.nn.functional as F

from rllm.utils import seg_softmax
from rllm.nn.conv.graph_conv import MessagePassing


class HANConv(MessagePassing):
    r"""The Heterogeneous Graph Attention Network (HAN) model implementation
    with message passing, as introduced in the `"Heterogeneous Graph
    Attention Network" <https://arxiv.org/abs/1903.07293>`__ paper.

    This model leverages the power of attention mechanisms in the context of
    heterogeneous graphs, allowing for the learning of node representations
    that capture both the structure and the multifaceted nature of the graph.

    Args:
        in_dim (int or Dict[str, int]): Size of each input sample of every
            node type. A single int is applied to all node types.
        out_dim (int): Size of each output sample of every node type.
            Must be divisible by :obj:`num_heads`.
        metadata (Tuple[List[str], List[Tuple[str, ...]]]): The metadata of
            the heterogeneous graph, *i.e.* its node and edge types given by
            a list of strings and a list of string tuples, respectively.
            An edge type is either a pair :obj:`(src, dst)` or a triplet
            :obj:`(src, rel, dst)`; the first element is taken as the source
            node type and the last one as the destination node type.
        num_heads (int, optional): Number of multi-head-attentions
            (default: :obj:`1`).
        negative_slope (float): LeakyReLU angle of the negative slope
            (default: :obj:`0.2`).
        dropout (float): Dropout probability of the normalized attention
            coefficients which exposes each node to a stochastically sampled
            neighborhood during training (default: :obj:`0`).
        aggr (str): The aggregation method to use. Defaults: 'sum'.
        **kwargs (optional): Additional arguments, forwarded to
            :class:`MessagePassing` as :obj:`aggr_kwargs`.

    Raises:
        ValueError: If :obj:`num_heads` is not positive or does not divide
            :obj:`out_dim`.
    """

    node_dim = 0

    def __init__(
        self,
        in_dim: Union[int, Dict[str, int]],
        out_dim: int,
        metadata: Tuple[List[str], List[Tuple[str, ...]]],
        num_heads: int = 1,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        *,
        aggr: str = "sum",
        **kwargs,
    ):
        # Fail early: forward() reshapes projections to (N, num_heads,
        # out_dim // num_heads), which is only well defined when the heads
        # split out_dim evenly.
        if num_heads <= 0 or out_dim % num_heads != 0:
            raise ValueError(
                f"out_dim ({out_dim}) must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )

        # default use 'sum' aggregator
        super().__init__(aggr=aggr, aggr_kwargs=kwargs)
        node_types, edge_types = metadata

        # If in_dim is not dict, use the same in_dim for all node types
        if not isinstance(in_dim, dict):
            in_dim = {node_type: in_dim for node_type in node_types}

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        # Linear projection, one layer per node type
        self.lin_dict = torch.nn.ModuleDict()
        for node_type, node_in_dim in self.in_dim.items():
            self.lin_dict[node_type] = torch.nn.Linear(node_in_dim, out_dim)

        # Multi-head node attention, one (src, dst) vector pair per edge type
        self.lin_src = torch.nn.ParameterDict()
        self.lin_dst = torch.nn.ParameterDict()
        hidden_dim = out_dim // num_heads
        for edge_type in edge_types:
            # ParameterDict keys must be strings, so edge types are joined.
            edge_key = "__".join(edge_type)
            self.lin_src[edge_key] = Parameter(
                torch.empty(1, num_heads, hidden_dim)
            )
            self.lin_dst[edge_key] = Parameter(
                torch.empty(1, num_heads, hidden_dim)
            )

        # meta-path attention
        self.k_lin = torch.nn.Linear(out_dim, out_dim, bias=True)
        self.q = Parameter(torch.empty(1, out_dim))

        # reset parameters
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the attention parameters with Xavier-normal values.

        This touches ``k_lin.weight``, ``q`` and every per-edge-type
        ``lin_src`` / ``lin_dst`` vector. The per-node-type projections in
        ``lin_dict`` and the bias of ``k_lin`` are left as they are.
        """
        torch.nn.init.xavier_normal_(self.k_lin.weight)
        torch.nn.init.xavier_normal_(self.q)
        for edge_key in self.lin_src.keys():
            torch.nn.init.xavier_normal_(self.lin_src[edge_key])
            torch.nn.init.xavier_normal_(self.lin_dst[edge_key])

    def forward(
        self,
        x_dict: Dict[str, Tensor],
        edge_index_dict: Dict[Tuple[str, ...], Union[Tensor, SparseTensor]],
        return_semantic_attn_weights: bool = False,
    ):
        r"""Apply HAN message passing over typed edges and semantic fusion.

        Args:
            x_dict (Dict[str, Tensor]): Mapping from node type to node
                feature matrix.
            edge_index_dict (Dict[Tuple[str, ...], Union[Tensor, SparseTensor]]):
                Mapping from edge types to typed graph connectivity. Keys are
                :obj:`(src, dst)` pairs or :obj:`(src, rel, dst)` triplets
                and must match the edge types given in :obj:`metadata`.
            return_semantic_attn_weights (bool): If True, also return
                semantic attention weights for each node type.

        Returns:
            Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Dict[str, Tensor]]]:
            Per-node-type output embeddings, optionally with semantic
            attention weights. A node type that is not the destination of
            any edge type in :obj:`edge_index_dict` receives no messages; its
            output embedding and its attention weights are :obj:`None`.

        Example:
            >>> import torch
            >>> from rllm.nn.conv.graph_conv import HANConv
            >>> metadata = (['u', 'i'], [('u', 'i')])
            >>> conv = HANConv(16, 8, metadata)
            >>> x_dict = {'u': torch.randn(2, 16), 'i': torch.randn(3, 16)}
            >>> edge_index_dict = {('u', 'i'): torch.tensor([[0, 1], [1, 2]])}
            >>> out_dict = conv(x_dict, edge_index_dict)
            >>> out_dict['i'].shape
            torch.Size([3, 8])
            >>> out_dict['u'] is None
            True
        """
        H, D = self.num_heads, self.out_dim // self.num_heads
        node_dict, out_dict = {}, {}

        # Linear projection
        for node_type, x in x_dict.items():
            # (N, in_dim) -> (N, H, D)
            node_dict[node_type] = self.lin_dict[node_type](x).view(-1, H, D)
            out_dict[node_type] = []

        # Iterate over edge types
        for edge_type, edge_index in edge_index_dict.items():
            # First / last element so both pairs and triplets are accepted.
            src_node_type, dst_node_type = edge_type[0], edge_type[-1]
            edge_key = "__".join(edge_type)

            # multi-head node attention
            # (N, H, D) * (1, H, D) -> (N, H)
            src_x = node_dict[src_node_type]
            dst_x = node_dict[dst_node_type]
            alpha_src = (self.lin_src[edge_key] * src_x).sum(dim=-1)
            alpha_dst = (self.lin_dst[edge_key] * dst_x).sum(dim=-1)
            alpha = (alpha_src, alpha_dst)

            # message passing
            out = self.propagate(
                None,
                edge_index=edge_index,
                src_x=src_x,
                alpha=alpha,
                dim_size=dst_x.size(0),
            )
            out = F.relu(out)
            out_dict[dst_node_type].append(out)

        # meta-path attention
        semantic_attn_dict = {}
        for node_type, outs in out_dict.items():
            if not outs:
                # No edge type targets this node type. torch.stack() rejects
                # an empty list, so report "no embedding" explicitly instead
                # of crashing the whole forward pass.
                out_dict[node_type] = None
                semantic_attn_dict[node_type] = None
                continue

            outs = torch.stack(outs, dim=0)  # (num_edge_types, N, out_dim)
            # Average over nodes to get one summary vector per edge type.
            k = torch.tanh(self.k_lin(outs)).mean(
                dim=1, keepdim=False
            )  # (num_edge_types, out_dim)
            attn_score = (self.q * k).sum(dim=-1, keepdim=False)  # (num_edge_types)
            attn = F.softmax(attn_score, dim=0)
            # Weighted sum of the per-edge-type embeddings.
            outs = attn.view(-1, 1, 1) * outs
            out_dict[node_type] = outs.sum(dim=0, keepdim=False)
            semantic_attn_dict[node_type] = attn

        if return_semantic_attn_weights:
            return out_dict, semantic_attn_dict

        return out_dict

    def message_and_aggregate(self, edge_index, src_x, alpha, dim_size):
        """Compute attention-weighted messages and aggregate them per target.

        Args:
            edge_index: Connectivity of one edge type; it is normalized via
                ``__unify_edgeindex__`` and its row ``1`` is used as the
                segment / aggregation index.
            src_x: Projected source node features, ``(N_src, H, D)``.
            alpha: Pair ``(alpha_src, alpha_dst)`` of per-node attention
                logits as built in :meth:`forward`.
            dim_size: Number of destination nodes, i.e. the number of softmax
                segments and of aggregated output rows.

        Returns:
            The result of ``self.aggr_module`` applied to the ``(E, H * D)``
            messages, grouped by ``edge_index[1, :]``.
        """
        edge_index, _ = self.__unify_edgeindex__(edge_index)
        # NOTE(review): retrieve_feats is defined by the base class; it
        # presumably gathers per-edge source / destination values - confirm.
        alpha_src, alpha_dst = self.retrieve_feats(alpha, edge_index)
        src_x = self.retrieve_feats(src_x, edge_index, dim=0)  # (E, H, D)

        # alpha: (E, H)
        alpha = alpha_src + alpha_dst
        alpha = F.leaky_relu(alpha, self.negative_slope)
        # Normalize the logits over the edges that share edge_index[1].
        alpha = seg_softmax(alpha, edge_index[1, :], num_segs=dim_size)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # msg: (E, out_dim[H*D])
        msgs = src_x * alpha.unsqueeze(-1)  # (E, H, D) * (E, H) -> (E, H, D)
        msgs = msgs.view(-1, self.out_dim)  # (E, H, D) -> (E, H*D)

        return self.aggr_module(
            msgs, edge_index[1, :], dim=self.node_dim, dim_size=dim_size
        )

    def __repr__(self):
        return f"{self.__class__.__name__}({self.out_dim}, num_heads={self.num_heads})"